Skip to content
Snippets Groups Projects
Commit 278a6198 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Merge branch 'add-datamodule-andre' into 'add-datamodule'

Reviewed DataModule design+docs+types

See merge request biosignal/software/ptbench!7
parents 6b6196a0 7eaac22c
No related branches found
No related tags found
1 merge request!7Reviewed DataModule design+docs+types
Pipeline #75780 failed
This commit is part of merge request !6. Comments created here will be created in the context of that merge request.
Showing
with 1449 additions and 1111 deletions
.. SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
.. coding=utf-8 ..
.. SPDX-License-Identifier: GPL-3.0-or-later
============ ============
References References
......
...@@ -15,33 +15,33 @@ dynamic = ["readme"] ...@@ -15,33 +15,33 @@ dynamic = ["readme"]
license = { text = "GNU General Public License v3 (GPLv3)" } license = { text = "GNU General Public License v3 (GPLv3)" }
authors = [{ name = "Geoffrey Raposo", email = "geoffrey@raposo.ch" }] authors = [{ name = "Geoffrey Raposo", email = "geoffrey@raposo.ch" }]
maintainers = [ maintainers = [
{ name = "Andre Anjos", email = "andre.anjos@idiap.ch" }, { name = "Andre Anjos", email = "andre.anjos@idiap.ch" },
{ name = "Daniel Carron", email = "daniel.carron@idiap.ch" }, { name = "Daniel Carron", email = "daniel.carron@idiap.ch" },
] ]
classifiers = [ classifiers = [
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Intended Audience :: Developers", "Intended Audience :: Developers",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
"Natural Language :: English", "Natural Language :: English",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Software Development :: Libraries :: Python Modules",
] ]
dependencies = [ dependencies = [
"clapper", "clapper",
"click", "click",
"numpy", "numpy",
"pandas", "pandas",
"scipy", "scipy",
"scikit-learn", "scikit-learn",
"tqdm", "tqdm",
"psutil", "psutil",
"tabulate", "tabulate",
"matplotlib", "matplotlib",
"pillow", "pillow",
"torch>=1.8", "torch>=1.8",
"torchvision>=0.10", "torchvision>=0.10",
"lightning>=2.0.3", "lightning>=2.0.3",
"tensorboard", "tensorboard",
] ]
[project.urls] [project.urls]
...@@ -53,13 +53,13 @@ changelog = "https://gitlab.idiap.ch/biosignal/software/ptbench/-/releases" ...@@ -53,13 +53,13 @@ changelog = "https://gitlab.idiap.ch/biosignal/software/ptbench/-/releases"
[project.optional-dependencies] [project.optional-dependencies]
qa = ["pre-commit"] qa = ["pre-commit"]
doc = [ doc = [
"sphinx", "sphinx",
"furo", "furo",
"sphinx-autodoc-typehints", "sphinx-autodoc-typehints",
"auto-intersphinx", "auto-intersphinx",
"sphinx-copybutton", "sphinx-copybutton",
"sphinx-inline-tabs", "sphinx-inline-tabs",
"sphinx-click", "sphinx-click",
] ]
test = ["pytest", "pytest-cov", "coverage"] test = ["pytest", "pytest-cov", "coverage"]
......
...@@ -6,19 +6,30 @@ ...@@ -6,19 +6,30 @@
from torch import empty from torch import empty
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD
from ...models.alexnet import Alexnet from ...models.alexnet import Alexnet
# config # optimizer
optimizer = SGD
optimizer_configs = {"lr": 0.01, "momentum": 0.1} optimizer_configs = {"lr": 0.01, "momentum": 0.1}
# optimizer
optimizer = "SGD"
# criterion # criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model # model
model = Alexnet( model = Alexnet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=False,
augmentation_transforms=augmentation_transforms,
) )
...@@ -6,19 +6,30 @@ ...@@ -6,19 +6,30 @@
from torch import empty from torch import empty
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD
from ...models.alexnet import Alexnet from ...models.alexnet import Alexnet
# config
optimizer_configs = {"lr": 0.001, "momentum": 0.1}
# optimizer # optimizer
optimizer = "SGD" optimizer = SGD
optimizer_configs = {"lr": 0.01, "momentum": 0.1}
# criterion # criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model # model
model = Alexnet( model = Alexnet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=True,
augmentation_transforms=augmentation_transforms,
) )
...@@ -6,20 +6,30 @@ ...@@ -6,20 +6,30 @@
from torch import empty from torch import empty
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...models.densenet import Densenet from ...models.densenet import Densenet
# config
optimizer_configs = {"lr": 0.0001}
# optimizer # optimizer
optimizer = "Adam" optimizer = Adam
optimizer_configs = {"lr": 0.0001}
# criterion # criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model # model
model = Densenet( model = Densenet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=False,
augmentation_transforms=augmentation_transforms,
) )
...@@ -6,20 +6,30 @@ ...@@ -6,20 +6,30 @@
from torch import empty from torch import empty
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...models.densenet import Densenet from ...models.densenet import Densenet
# config
optimizer_configs = {"lr": 0.01}
# optimizer # optimizer
optimizer = "Adam" optimizer = Adam
optimizer_configs = {"lr": 0.0001}
# criterion # criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model # model
model = Densenet( model = Densenet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=True,
augmentation_transforms=augmentation_transforms,
) )
...@@ -11,20 +11,16 @@ Screening and Visualization". ...@@ -11,20 +11,16 @@ Screening and Visualization".
Reference: [PASA-2019]_ Reference: [PASA-2019]_
""" """
from torch import empty
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...models.pasa import PASA
from ...data.transforms import ElasticDeformation
# config from ...models.pasa import Pasa
optimizer_configs = {"lr": 8e-5}
model = Pasa(
# optimizer train_loss=BCEWithLogitsLoss(),
optimizer = "Adam" validation_loss=BCEWithLogitsLoss(),
optimizer_type=Adam,
# criterion optimizer_arguments=dict(lr=8e-5),
criterion = BCEWithLogitsLoss(pos_weight=empty(1)) augmentation_transforms=[ElasticDeformation(p=0.8)],
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) )
# model
model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import multiprocessing
import sys
import lightning as pl
import torch
from clapper.logging import setup
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from .dataset import get_samples_weights
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class BaseDataModule(pl.LightningDataModule):
def __init__(
self,
batch_size=1,
batch_chunk_count=1,
drop_incomplete_batch=False,
parallel=-1,
):
super().__init__()
self.batch_size = batch_size
self.batch_chunk_count = batch_chunk_count
self.train_dataset = None
self.validation_dataset = None
self.extra_validation_datasets = None
self._raw_data_transforms = []
self._model_transforms = []
self._augmentation_transforms = []
self.drop_incomplete_batch = drop_incomplete_batch
self.pin_memory = (
torch.cuda.is_available()
) # should only be true if GPU available and using it
self.parallel = parallel
def setup(self, stage: str):
# Implemented by user
# Must define self.train_dataset, self.validation_dataset, self.extra_validation_datasets and self.predict_dataset
raise NotImplementedError
def train_dataloader(self):
train_samples_weights = get_samples_weights(self.train_dataset)
multiproc_kwargs = self._setup_multiproc(self.parallel)
train_sampler = WeightedRandomSampler(
train_samples_weights, len(train_samples_weights), replacement=True
)
return DataLoader(
self.train_dataset,
batch_size=self._compute_chunk_size(
self.batch_size, self.batch_chunk_count
),
drop_last=self.drop_incomplete_batch,
pin_memory=self.pin_memory,
sampler=train_sampler,
**multiproc_kwargs,
)
def val_dataloader(self):
loaders_dict = {}
multiproc_kwargs = self._setup_multiproc(self.parallel)
val_loader = DataLoader(
dataset=self.validation_dataset,
batch_size=self._compute_chunk_size(
self.batch_size, self.batch_chunk_count
),
shuffle=False,
drop_last=False,
pin_memory=self.pin_memory,
**multiproc_kwargs,
)
loaders_dict["validation_loader"] = val_loader
if self.extra_validation_datasets is not None:
for set_idx, extra_set in enumerate(self.extra_validation_datasets):
extra_val_loader = DataLoader(
dataset=extra_set,
batch_size=self._compute_chunk_size(
self.batch_size, self.batch_chunk_count
),
shuffle=False,
drop_last=False,
pin_memory=self.pin_memory,
**multiproc_kwargs,
)
loaders_dict[
f"extra_validation_loader{set_idx}"
] = extra_val_loader
return loaders_dict
def predict_dataloader(self):
loaders_dict = {}
loaders_dict["train_loader"] = self.train_dataloader()
for k, v in self.val_dataloader().items():
loaders_dict[k] = v
return loaders_dict
def update_module_properties(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
def _compute_chunk_size(self, batch_size, chunk_count):
batch_chunk_size = batch_size
if batch_size % chunk_count != 0:
# batch_size must be divisible by batch_chunk_count.
raise RuntimeError(
f"--batch-size ({batch_size}) must be divisible by "
f"--batch-chunk-size ({chunk_count})."
)
else:
batch_chunk_size = batch_size // chunk_count
return batch_chunk_size
def _setup_multiproc(self, parallel):
multiproc_kwargs = dict()
if parallel < 0:
multiproc_kwargs["num_workers"] = 0
else:
multiproc_kwargs["num_workers"] = (
parallel or multiprocessing.cpu_count()
)
if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
multiproc_kwargs[
"multiprocessing_context"
] = multiprocessing.get_context("spawn")
return multiproc_kwargs
def _build_transforms(self, is_train):
all_transforms = self.raw_data_transforms + self.model_transforms
if is_train:
all_transforms = all_transforms + self.augmentation_transforms
# all_transforms.append(transforms.ToTensor())
return transforms.Compose(all_transforms)
@property
def raw_data_transforms(self):
return self._raw_data_transforms
@raw_data_transforms.setter
def raw_data_transforms(self, transforms):
self._raw_data_transforms = transforms
@property
def model_transforms(self):
return self._model_transforms
@model_transforms.setter
def model_transforms(self, transforms):
self._model_transforms = transforms
@property
def augmentation_transforms(self):
return self._augmentation_transforms
@augmentation_transforms.setter
def augmentation_transforms(self, transforms):
self._augmentation_transforms = transforms
def get_dataset_from_module(module, stage, **module_args):
"""Instantiates a DataModule and retrieves the corresponding dataset.
Useful when combining multiple datasets.
"""
module_instance = module(**module_args)
module_instance.prepare_data()
module_instance.setup(stage=stage)
dataset = module_instance.dataset
return dataset
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import csv
import json
import logging
import os
import pathlib
import torch
from tqdm import tqdm
logger = logging.getLogger(__name__)
class JSONProtocol:
"""Generic multi-protocol/subset filelist dataset that yields samples.
To create a new dataset, you need to provide one or more JSON formatted
filelists (one per protocol) with the following contents:
.. code-block:: json
{
"subset1": [
[
"value1",
"value2",
"value3"
],
[
"value4",
"value5",
"value6"
]
],
"subset2": [
]
}
Your dataset many contain any number of subsets, but all sample entries
must contain the same number of fields.
Parameters
----------
protocols : list, dict
Paths to one or more JSON formatted files containing the various
protocols to be recognized by this dataset, or a dictionary, mapping
protocol names to paths (or opened file objects) of CSV files.
Internally, we save a dictionary where keys default to the basename of
paths (list input).
fieldnames : list, tuple
An iterable over the field names (strings) to assign to each entry in
the JSON file. It should have as many items as fields in each entry of
the JSON file.
loader : object
A function that receives as input, a context dictionary (with at least
a "protocol" and "subset" keys indicating which protocol and subset are
being served), and a dictionary with ``{fieldname: value}`` entries,
and returns an object with at least 2 attributes:
* ``key``: which must be a unique string for every sample across
subsets in a protocol, and
* ``data``: which contains the data associated witht this sample
"""
def __init__(self, protocols, fieldnames):
if isinstance(protocols, dict):
self._protocols = protocols
else:
self._protocols = {
os.path.basename(
str(k).replace("".join(pathlib.Path(k).suffixes), "")
): k
for k in protocols
}
self.fieldnames = fieldnames
def check(self, limit=0):
"""For each protocol, check if all data can be correctly accessed.
This function assumes each sample has a ``data`` and a ``key``
attribute. The ``key`` attribute should be a string, or representable
as such.
Parameters
----------
limit : int
Maximum number of samples to check (in each protocol/subset
combination) in this dataset. If set to zero, then check
everything.
Returns
-------
errors : int
Number of errors found
"""
logger.info("Checking dataset...")
errors = 0
for proto in self._protocols:
logger.info(f"Checking protocol '{proto}'...")
for name, samples in self.subsets(proto).items():
logger.info(f"Checking subset '{name}'...")
if limit:
logger.info(f"Checking at most first '{limit}' samples...")
samples = samples[:limit]
for pos, sample in enumerate(samples):
try:
sample.data # may trigger data loading
logger.info(f"{sample.key}: OK")
except Exception as e:
logger.error(
f"Found error loading entry {pos} in subset {name} "
f"of protocol {proto} from file "
f"'{self._protocols[proto]}': {e}"
)
errors += 1
return errors
def subsets(self, protocol):
"""Returns all subsets in a protocol.
This method will load JSON information for a given protocol and return
all subsets of the given protocol after converting each entry through
the loader function.
Parameters
----------
protocol : str
Name of the protocol data to load
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of objects (respecting
the ``key``, ``data`` interface).
"""
fileobj = self._protocols[protocol]
if isinstance(fileobj, (str, bytes, pathlib.Path)):
if str(fileobj).endswith(".bz2"):
import bz2
with bz2.open(self._protocols[protocol]) as f:
data = json.load(f)
else:
with open(self._protocols[protocol]) as f:
data = json.load(f)
else:
data = json.load(fileobj)
fileobj.seek(0)
retval = {}
for subset, samples in data.items():
logger.info(f"Loading subset {subset} samples.")
retval[subset] = [
dict(zip(self.fieldnames, k))
for n, k in enumerate(tqdm(samples))
]
return retval
class CSVDataset:
"""Generic multi-subset filelist dataset that yields samples.
To create a new dataset, you only need to provide a CSV formatted filelist
using any separator (e.g. comma, space, semi-colon) with the following
information:
.. code-block:: text
value1,value2,value3
value4,value5,value6
...
Notice that all rows must have the same number of entries.
Parameters
----------
subsets : list, dict
Paths to one or more CSV formatted files containing the various subsets
to be recognized by this dataset, or a dictionary, mapping subset names
to paths (or opened file objects) of CSV files. Internally, we save a
dictionary where keys default to the basename of paths (list input).
fieldnames : list, tuple
An iterable over the field names (strings) to assign to each column in
the CSV file. It should have as many items as fields in each row of
the CSV file(s).
loader : object
A function that receives as input, a context dictionary (with, at
least, a "subset" key indicating which subset is being served), and a
dictionary with ``{key: path}`` entries, and returns a dictionary with
the loaded data.
"""
def __init__(self, subsets, fieldnames, loader):
if isinstance(subsets, dict):
self._subsets = subsets
else:
self._subsets = {
os.path.basename(
str(k).replace("".join(pathlib.Path(k).suffixes), "")
): k
for k in subsets
}
self.fieldnames = fieldnames
self._loader = loader
def check(self, limit=0):
"""For each subset, check if all data can be correctly accessed.
This function assumes each sample has a ``data`` and a ``key``
attribute. The ``key`` attribute should be a string, or representable
as such.
Parameters
----------
limit : int
Maximum number of samples to check (in each protocol/subset
combination) in this dataset. If set to zero, then check
everything.
Returns
-------
errors : int
Number of errors found
"""
logger.info("Checking dataset...")
errors = 0
for name in self._subsets.keys():
logger.info(f"Checking subset '{name}'...")
samples = self.samples(name)
if limit:
logger.info(f"Checking at most first '{limit}' samples...")
samples = samples[:limit]
for pos, sample in enumerate(samples):
try:
sample.data # may trigger data loading
logger.info(f"{sample.key}: OK")
except Exception as e:
logger.error(
f"Found error loading entry {pos} in subset {name} "
f"from file '{self._subsets[name]}': {e}"
)
errors += 1
return errors
def subsets(self):
"""Returns all available subsets at once.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of objects (respecting
the ``key``, ``data`` interface).
"""
return {k: self.samples(k) for k in self._subsets.keys()}
def samples(self, subset):
"""Returns all samples in a subset.
This method will load CSV information for a given subset and return
all samples of the given subset after passing each entry through the
loading function.
Parameters
----------
subset : str
Name of the subset data to load
Returns
-------
subset : list
A lists of objects (respecting the ``key``, ``data`` interface).
"""
fileobj = self._subsets[subset]
if isinstance(fileobj, (str, bytes, pathlib.Path)):
with open(self._subsets[subset], newline="") as f:
cf = csv.reader(f)
samples = [k for k in cf]
else:
cf = csv.reader(fileobj)
samples = [k for k in cf]
fileobj.seek(0)
return [
self._loader(
dict(subset=subset, order=n), dict(zip(self.fieldnames, k))
)
for n, k in enumerate(samples)
]
class CachedDataset(torch.utils.data.Dataset):
def __init__(
self, json_protocol, protocol, subset, raw_data_loader, transforms
):
self.json_protocol = json_protocol
self.subset = subset
self.raw_data_loader = raw_data_loader
self.transforms = transforms
self._samples = json_protocol.subsets(protocol)[self.subset]
# Dict entry with relative path to files, used during prediction
for s in self._samples:
s["name"] = s["data"]
logger.info(f"Caching {self.subset} samples")
for sample in tqdm(self._samples):
sample["data"] = self.raw_data_loader(sample["data"])
def __getitem__(self, idx):
sample = self._samples[idx].copy()
sample["data"] = self.transforms(sample["data"])
return sample
def __len__(self):
return len(self._samples)
class RuntimeDataset(torch.utils.data.Dataset):
def __init__(
self, json_protocol, protocol, subset, raw_data_loader, transforms
):
self.json_protocol = json_protocol
self.subset = subset
self.raw_data_loader = raw_data_loader
self.transforms = transforms
self._samples = json_protocol.subsets(protocol)[self.subset]
# Dict entry with relative path to files
for s in self._samples:
s["name"] = s["data"]
def __getitem__(self, idx):
sample = self._samples[idx].copy()
sample["data"] = self.transforms(self.raw_data_loader(sample["data"]))
return sample
def __len__(self):
return len(self._samples)
def get_samples_weights(dataset):
"""Compute the weights of all the samples of the dataset to balance it
using the sampler of the dataloader.
This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
and computes the weights to balance each class in the dataset and the
datasets themselves if we have a ConcatDataset.
Parameters
----------
dataset : torch.utils.data.dataset.Dataset
An instance of torch.utils.data.dataset.Dataset
ConcatDataset are supported
Returns
-------
samples_weights : :py:class:`torch.Tensor`
the weights for all the samples in the dataset given as input
"""
samples_weights = []
if isinstance(dataset, torch.utils.data.ConcatDataset):
for ds in dataset.datasets:
# Weighting only for binary labels
if isinstance(ds._samples[0]["label"], int):
# Groundtruth
targets = []
for s in ds._samples:
targets.append(s["label"])
targets = torch.tensor(targets)
# Count number of samples per class
class_sample_count = torch.tensor(
[
(targets == t).sum()
for t in torch.unique(targets, sorted=True)
]
)
weight = 1.0 / class_sample_count.float()
samples_weights.append(
torch.tensor([weight[t] for t in targets])
)
else:
# We only weight to sample equally from each dataset
samples_weights.append(torch.full((len(ds),), 1.0 / len(ds)))
# Concatenate sample weights from all the datasets
samples_weights = torch.cat(samples_weights)
else:
# Weighting only for binary labels
if isinstance(dataset._samples[0]["label"], int):
# Groundtruth
targets = []
for s in dataset._samples:
targets.append(s["label"])
targets = torch.tensor(targets)
# Count number of samples per class
class_sample_count = torch.tensor(
[
(targets == t).sum()
for t in torch.unique(targets, sorted=True)
]
)
weight = 1.0 / class_sample_count.float()
samples_weights = torch.tensor([weight[t] for t in targets])
else:
# Equal weights for non-binary labels
samples_weights = torch.ones(len(dataset._samples))
return samples_weights
def get_positive_weights(dataset):
"""Compute the positive weights of each class of the dataset to balance the
BCEWithLogitsLoss criterion.
This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
and computes the positive weights of each class to use them to have
a balanced loss.
Parameters
----------
dataset : torch.utils.data.dataset.Dataset
An instance of torch.utils.data.dataset.Dataset
ConcatDataset are supported
Returns
-------
positive_weights : :py:class:`torch.Tensor`
the positive weight of each class in the dataset given as input
"""
targets = []
if isinstance(dataset, torch.utils.data.ConcatDataset):
for ds in dataset.datasets:
for s in ds._samples:
targets.append(s["label"])
else:
for s in dataset._samples:
targets.append(s["label"])
targets = torch.tensor(targets)
# Binary labels
if len(list(targets.shape)) == 1:
class_sample_count = [
float((targets == t).sum().item())
for t in torch.unique(targets, sorted=True)
]
# Divide negatives by positives
positive_weights = torch.tensor(
[class_sample_count[0] / class_sample_count[1]]
).reshape(-1)
# Multiclass labels
else:
class_sample_count = torch.sum(targets, dim=0)
negative_class_sample_count = (
torch.full((targets.size()[1],), float(targets.size()[0]))
- class_sample_count
)
positive_weights = negative_class_sample_count / (
class_sample_count + negative_class_sample_count
)
return positive_weights
def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
from torch.nn import BCEWithLogitsLoss
datamodule.prepare_data()
datamodule.setup(stage="fit")
train_dataset = datamodule.train_dataset
validation_dataset = datamodule.validation_dataset
# Redefine a weighted criterion if possible
if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
positive_weights = get_positive_weights(train_dataset)
model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
else:
logger.warning("Weighted criterion not supported")
if validation_dataset is not None:
# Redefine a weighted valid criterion if possible
if (
isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
or criterion_valid is None
):
positive_weights = get_positive_weights(validation_dataset)
model.hparams.criterion_valid = BCEWithLogitsLoss(
pos_weight=positive_weights
)
else:
logger.warning("Weighted valid criterion not supported")
def normalize_data(normalization, model, datamodule):
from torch.utils.data import DataLoader
datamodule.prepare_data()
datamodule.setup(stage="fit")
train_dataset = datamodule.train_dataset
# Create z-normalization model layer if needed
if normalization == "imagenet":
model.normalizer.set_mean_std(
[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
)
logger.info("Z-normalization with ImageNet mean and std")
elif normalization == "current":
# Compute mean/std of current train subset
temp_dl = DataLoader(
dataset=train_dataset, batch_size=len(train_dataset)
)
data = next(iter(temp_dl))
mean = data[1].mean(dim=[0, 2, 3])
std = data[1].std(dim=[0, 2, 3])
model.normalizer.set_mean_std(mean, std)
# Format mean and std for logging
mean = str(
[
round(x, 3)
for x in ((mean * 10**3).round() / (10**3)).tolist()
]
)
std = str(
[
round(x, 3)
for x in ((std * 10**3).round() / (10**3)).tolist()
]
)
logger.info(f"Z-normalization with mean {mean} and std {std}")
...@@ -5,61 +5,96 @@ ...@@ -5,61 +5,96 @@
"""Data loading code.""" """Data loading code."""
import pathlib
import numpy
import PIL.Image import PIL.Image
def load_pil(path): class SingleAutoLevel16to8:
"""Converts a 16-bit image to 8-bit representation using "auto-level".
This transform assumes that the input image is gray-scaled.
To auto-level, we calculate the maximum and the minimum of the image, and
consider such a range should be mapped to the [0,255] range of the
destination image.
"""
def __call__(self, img):
imin, imax = img.getextrema()
irange = imax - imin
return PIL.Image.fromarray(
numpy.round(
255.0 * (numpy.array(img).astype(float) - imin) / irange
).astype("uint8"),
).convert("L")
class RemoveBlackBorders:
"""Remove black borders of CXR."""
def __init__(self, threshold=0):
self.threshold = threshold
def __call__(self, img):
img = numpy.asarray(img)
mask = numpy.asarray(img) > self.threshold
return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
def load_pil(path: str | pathlib.Path) -> PIL.Image.Image:
"""Loads a sample data. """Loads a sample data.
Parameters Parameters
---------- ----------
path : str path
The full path leading to the image to be loaded The full path leading to the image to be loaded
Returns Returns
------- -------
image : PIL.Image.Image image
A PIL image A PIL image
""" """
return PIL.Image.open(path) return PIL.Image.open(path)
def load_pil_baw(path): def load_pil_baw(path: str | pathlib.Path) -> PIL.Image.Image:
"""Loads a sample data. """Loads a sample data.
Parameters Parameters
---------- ----------
path : str path
The full path leading to the image to be loaded The full path leading to the image to be loaded
Returns Returns
------- -------
image : PIL.Image.Image image
A PIL image in grayscale mode A PIL image in grayscale mode
""" """
return load_pil(path).convert("L") return load_pil(path).convert("L")
def load_pil_rgb(path): def load_pil_rgb(path: str | pathlib.Path) -> PIL.Image.Image:
"""Loads a sample data. """Loads a sample data.
Parameters Parameters
---------- ----------
path : str path
The full path leading to the image to be loaded The full path leading to the image to be loaded
Returns Returns
------- -------
image : PIL.Image.Image image
A PIL image in RGB mode A PIL image in RGB mode
""" """
return load_pil(path).convert("RGB") return load_pil(path).convert("RGB")
...@@ -264,7 +264,7 @@ json_dataset = JSONDataset( ...@@ -264,7 +264,7 @@ json_dataset = JSONDataset(
def _maker(protocol, resize_size=512, cc_size=512, RGB=True): def _maker(protocol, resize_size=512, cc_size=512, RGB=True):
import torchvision.transforms as transforms import torchvision.transforms as transforms
from ..transforms import SingleAutoLevel16to8 from ..loader import SingleAutoLevel16to8
post_transforms = [] post_transforms = []
if not RGB: if not RGB:
......
...@@ -13,22 +13,11 @@ the daily routine using Philips DR Digital Diagnose systems. ...@@ -13,22 +13,11 @@ the daily routine using Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_ * Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less * Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none * Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels) * Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels) * Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels) * Test samples: 20% of TB and healthy CXR (including labels)
""" """
import importlib.resources import importlib.resources
import os
from clapper.logging import setup
from ...utils.rc import load_rc
from ..loader import load_pil_baw
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
_protocols = [ _protocols = [
importlib.resources.files(__name__).joinpath("default.json.bz2"), importlib.resources.files(__name__).joinpath("default.json.bz2"),
...@@ -43,11 +32,3 @@ _protocols = [ ...@@ -43,11 +32,3 @@ _protocols = [
importlib.resources.files(__name__).joinpath("fold_8.json.bz2"), importlib.resources.files(__name__).joinpath("fold_8.json.bz2"),
importlib.resources.files(__name__).joinpath("fold_9.json.bz2"), importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
] ]
_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
def _raw_data_loader(img_path):
raw_data = load_pil_baw(os.path.join(_datadir, img_path))
return raw_data
...@@ -2,27 +2,46 @@ ...@@ -2,27 +2,46 @@
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for TB detection (default protocol) """Shenzhen datamodule for computer-aided diagnosis (default protocol)
* Split reference: first 64% of TB and healthy CXR for "train" 16% for See :py:mod:`ptbench.data.shenzhen` for more database details.
* "validation", 20% for "test"
* This configuration resolution: 512 x 512 (default) This configuration:
* See :py:mod:`ptbench.data.shenzhen` for dataset details
""" * Raw data input (on disk):
* PNG images (black and white, encoded as color images)
* Variable width and height:
from clapper.logging import setup * widths: from 1130 to 3001 pixels
* heights: from 948 to 3001 pixels
from ..transforms import ElasticDeformation * Output image:
from .utils import ShenzhenDataModule
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") * Transforms:
* Load raw PNG with :py:mod:`PIL`
* Remove black borders
* Torch resizing(512px, 512px)
* Torch center cropping (512px, 512px)
* Final specifications:
* Fixed resolution: 512x512 pixels
* Color RGB encoding
"""
protocol_name = "default" import importlib.resources
augmentation_transforms = [ElasticDeformation(p=0.8)] from ..datamodule import CachingDataModule
from ..split import JSONDatabaseSplit
from .loader import RawDataLoader
datamodule = ShenzhenDataModule( datamodule = CachingDataModule(
protocol="default", database_split=JSONDatabaseSplit(
model_transforms=[], importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
augmentation_transforms=augmentation_transforms, "default.json.bz2"
)
),
raw_data_loader=RawDataLoader(),
) )
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for computer-aided diagnosis.
The standard digital image database for Tuberculosis is created by the National
Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s
Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from
out-patient clinics, and were captured as part of the daily routine using
Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels)
"""
import os
import torchvision.transforms
from ...utils.rc import load_rc
from ..image_utils import RemoveBlackBorders, load_pil_baw
from ..typing import RawDataLoader as _BaseRawDataLoader
from ..typing import Sample
class RawDataLoader(_BaseRawDataLoader):
"""A specialized raw-data-loader for the Shenzen dataset.
Attributes
----------
datadir
This variable contains the base directory where the database raw data
is stored.
transform
Transforms that are always applied to the loaded raw images.
"""
datadir: str
transform: torchvision.transforms.Compose
def __init__(self):
self.datadir = load_rc().get(
"datadir.shenzhen", os.path.realpath(os.curdir)
)
self.transform = torchvision.transforms.Compose(
[
RemoveBlackBorders(),
torchvision.transforms.Resize(512),
torchvision.transforms.CenterCrop(512),
torchvision.transforms.ToTensor(),
]
)
def sample(self, sample: tuple[str, int]) -> Sample:
"""Loads a single image sample from the disk.
Parameters
----------
sample:
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing the
sample label.
Returns
-------
sample
The sample representation
"""
tensor = self.transform(
load_pil_baw(os.path.join(self.datadir, sample[0]))
)
return tensor, dict(label=sample[1]) # type: ignore[arg-type]
def label(self, sample: tuple[str, int]) -> int:
"""Loads a single image sample label from the disk.
Parameters
----------
sample:
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing the
sample label.
Returns
-------
label
The integer label associated with the sample
"""
return sample[1]
...@@ -2,81 +2,40 @@ ...@@ -2,81 +2,40 @@
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for TB detection (cross validation fold 0, RGB) """Shenzhen datamodule for computer-aided diagnosis (default protocol)
* Split reference: first 80% of TB and healthy CXR for "train", rest for "test" See :py:mod:`ptbench.data.shenzhen` for dataset details.
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.shenzhen` for dataset details
"""
from clapper.logging import setup
from ....data import return_subsets
from ....data.base_datamodule import BaseDataModule
from ....data.dataset import JSONProtocol
from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") This configuration:
* raw data (default): :py:obj:`ptbench.data.shenzhen._tranforms`
* augmentations: elastic deformation (probability = 80%)
class DefaultModule(BaseDataModule): * output image resolution: 512x512 pixels
def __init__( """
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
cache_samples=False,
multiproc_kwargs=None,
data_transforms=[],
model_transforms=[],
train_transforms=[],
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
self.cache_samples = cache_samples
self.has_setup_fit = False
self.data_transforms = data_transforms import importlib.resources
self.model_transforms = model_transforms
self.train_transforms = train_transforms
"""[ from torchvision import transforms
transforms.ToPILImage(),
transforms.Lambda(lambda x: x.convert("RGB")),
transforms.ToTensor(),
]"""
def setup(self, stage: str): from ..datamodule import CachingDataModule
if self.cache_samples: from ..split import JSONDatabaseSplit
logger.info( from .raw_data_loader import raw_data_loader
"Argument cache_samples set to True. Samples will be loaded in memory."
)
samples_loader = _cached_loader
else:
logger.info(
"Argument cache_samples set to False. Samples will be loaded at runtime."
)
samples_loader = _delayed_loader
self.json_protocol = JSONProtocol( datamodule = CachingDataModule(
protocols=_protocols, database_split=JSONDatabaseSplit(
fieldnames=("data", "label"), importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
loader=samples_loader, "default.json.bz2"
post_transforms=self.post_transforms,
) )
),
if not self.has_setup_fit and stage == "fit": raw_data_loader=raw_data_loader,
( cache_samples=False,
self.train_dataset, # train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
self.validation_dataset, model_transforms=[
self.extra_validation_datasets, transforms.ToPILImage(),
) = return_subsets(self.json_protocol, "default", stage) transforms.Lambda(lambda x: x.convert("RGB")),
self.has_setup_fit = True transforms.ToTensor(),
],
# batch_size = 1,
datamodule = DefaultModule # batch_chunk_count = 1,
# drop_incomplete_batch = False,
# parallel = -1,
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for computer-aided diagnosis.
The standard digital image database for Tuberculosis is created by the
National Library of Medicine, Maryland, USA in collaboration with Shenzhen
No.3 People’s Hospital, Guangdong Medical College, Shenzhen, China.
The Chest X-rays are from out-patient clinics, and were captured as part of
the daily routine using Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels)
"""
from clapper.logging import setup
from torchvision import transforms
from ..base_datamodule import BaseDataModule
from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
from ..shenzhen import _protocols, _raw_data_loader
from ..transforms import RemoveBlackBorders
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class ShenzhenDataModule(BaseDataModule):
def __init__(
self,
protocol="default",
model_transforms=[],
augmentation_transforms=[],
batch_size=1,
batch_chunk_count=1,
drop_incomplete_batch=False,
cache_samples=False,
parallel=-1,
):
super().__init__(
batch_size=batch_size,
drop_incomplete_batch=drop_incomplete_batch,
batch_chunk_count=batch_chunk_count,
parallel=parallel,
)
self._cache_samples = cache_samples
self._has_setup_fit = False
self._has_setup_predict = False
self._protocol = protocol
self.raw_data_transforms = [
RemoveBlackBorders(),
transforms.Resize(512),
transforms.CenterCrop(512),
transforms.ToTensor(),
]
self.model_transforms = model_transforms
self.augmentation_transforms = augmentation_transforms
def setup(self, stage: str):
json_protocol = JSONProtocol(
protocols=_protocols,
fieldnames=("data", "label"),
)
if self._cache_samples:
dataset = CachedDataset
else:
dataset = RuntimeDataset
if not self._has_setup_fit and stage == "fit":
self.train_dataset = dataset(
json_protocol,
self._protocol,
"train",
_raw_data_loader,
self._build_transforms(is_train=True),
)
self.validation_dataset = dataset(
json_protocol,
self._protocol,
"validation",
_raw_data_loader,
self._build_transforms(is_train=False),
)
self._has_setup_fit = True
if not self._has_setup_predict and stage == "predict":
self.train_dataset = dataset(
json_protocol,
self._protocol,
"train",
_raw_data_loader,
self._build_transforms(is_train=False),
)
self.validation_dataset = dataset(
json_protocol,
self._protocol,
"validation",
_raw_data_loader,
self._build_transforms(is_train=False),
)
self._has_setup_predict = True
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import csv
import importlib.abc
import json
import logging
import pathlib
import typing
import torch
from .typing import DatabaseSplit, RawDataLoader
logger = logging.getLogger(__name__)
class JSONDatabaseSplit(DatabaseSplit):
"""Defines a loader that understands a database split (train, test, etc) in
JSON format.
To create a new database split, you need to provide a JSON formatted
dictionary in a file, with contents similar to the following:
.. code-block:: json
{
"subset1": [
[
"sample1-data1",
"sample1-data2",
"sample1-data3",
],
[
"sample2-data1",
"sample2-data2",
"sample2-data3",
]
],
"subset2": [
[
"sample42-data1",
"sample42-data2",
"sample42-data3",
],
]
}
Your database split many contain any number of subsets (dictionary keys).
For simplicity, we recommend all sample entries are formatted similarly so
that raw-data-loading is simplified. Use the function
:py:func:`check_database_split_loading` to test raw data loading and fine
tune the dataset split, or its loading.
Objects of this class behave like a dictionary in which keys are subset
names in the split, and values represent samples data and meta-data.
Parameters
----------
path
Absolute path to a JSON formatted file containing the database split to be
recognized by this object.
"""
def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable):
if isinstance(path, str):
path = pathlib.Path(path)
self.path = path
self.subsets = self._load_split_from_disk()
def _load_split_from_disk(self) -> DatabaseSplit:
"""Loads all subsets in a split from its file system representation.
This method will load JSON information for the current split and return
all subsets of the given split after converting each entry through the
loader function.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of JSON objects
"""
if str(self.path).endswith(".bz2"):
logger.debug(f"Loading database split from {str(self.path)}...")
with __import__("bz2").open(self.path) as f:
return json.load(f)
else:
with self.path.open() as f:
return json.load(f)
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self.subsets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self.subsets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self.subsets)
class CSVDatabaseSplit(DatabaseSplit):
"""Defines a loader that understands a database split (train, test, etc) in
CSV format.
To create a new database split, you need to provide one or more CSV
formatted files, each representing a subset of this split, containing the
sample data (one per row). Example:
Inside the directory ``my-split/``, one can file files ``train.csv``,
``validation.csv``, and ``test.csv``. Each file has a structure similar to
the following:
.. code-block:: text
sample1-value1,sample1-value2,sample1-value3
sample2-value1,sample2-value2,sample2-value3
...
Each file in the provided directory defines the subset name on the split.
So, the file ``train.csv`` will contain the data from the ``train`` subset,
and so on.
Objects of this class behave like a dictionary in which keys are subset
names in the split, and values represent samples data and meta-data.
Parameters
----------
directory
Absolute path to a directory containing the database split layed down
as a set of CSV files, one per subset.
"""
def __init__(
self, directory: pathlib.Path | str | importlib.abc.Traversable
):
if isinstance(directory, str):
directory = pathlib.Path(directory)
assert (
directory.is_dir()
), f"`{str(directory)}` is not a valid directory"
self.directory = directory
self.subsets = self._load_split_from_disk()
def _load_split_from_disk(self) -> DatabaseSplit:
"""Loads all subsets in a split from its file system representation.
This method will load CSV information for the current split and return all
subsets of the given split after converting each entry through the
loader function.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of JSON objects
"""
retval: DatabaseSplit = {}
for subset in self.directory.iterdir():
if str(subset).endswith(".csv.bz2"):
logger.debug(f"Loading database split from {subset}...")
with __import__("bz2").open(subset) as f:
reader = csv.reader(f)
retval[subset.name[: -len(".csv.bz2")]] = [
k for k in reader
]
elif str(subset).endswith(".csv"):
with subset.open() as f:
reader = csv.reader(f)
retval[subset.name[: -len(".csv")]] = [k for k in reader]
else:
logger.debug(
f"Ignoring file {subset} in CSVDatabaseSplit readout"
)
return retval
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self.subsets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self.subsets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self.subsets)
def check_database_split_loading(
database_split: DatabaseSplit,
loader: RawDataLoader,
limit: int = 0,
) -> int:
"""For each subset in the split, check if all data can be correctly loaded
using the provided loader function.
This function will return the number of errors loading samples, and will
log more detailed information to the logging stream.
Parameters
----------
database_split
A mapping that, contains the database split. Each key represents the
name of a subset in the split. Each value is a (potentially complex)
object that represents a single sample.
loader
A loader object that knows how to handle full-samples or just labels.
limit
Maximum number of samples to check (in each split/subset
combination) in this dataset. If set to zero, then check
everything.
Returns
-------
errors
Number of errors found
"""
logger.info(
"Checking if can load all samples in all subsets of this split..."
)
errors = 0
for subset in database_split.keys():
samples = subset if not limit else subset[:limit]
for pos, sample in enumerate(samples):
try:
data, _ = loader.sample(sample)
assert isinstance(data, torch.Tensor)
except Exception as e:
logger.info(
f"Found error loading entry {pos} in subset `{subset}`: {e}"
)
errors += 1
return errors
...@@ -22,39 +22,6 @@ from scipy.ndimage import gaussian_filter, map_coordinates ...@@ -22,39 +22,6 @@ from scipy.ndimage import gaussian_filter, map_coordinates
from torchvision import transforms from torchvision import transforms
class SingleAutoLevel16to8:
"""Converts a 16-bit image to 8-bit representation using "auto-level".
This transform assumes that the input image is gray-scaled.
To auto-level, we calculate the maximum and the minimum of the
image, and
consider such a range should be mapped to the [0,255] range of the
destination image.
"""
def __call__(self, img):
imin, imax = img.getextrema()
irange = imax - imin
return PIL.Image.fromarray(
numpy.round(
255.0 * (numpy.array(img).astype(float) - imin) / irange
).astype("uint8"),
).convert("L")
class RemoveBlackBorders:
"""Remove black borders of CXR."""
def __init__(self, threshold=0):
self.threshold = threshold
def __call__(self, img):
img = numpy.asarray(img)
mask = numpy.asarray(img) > self.threshold
return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
class ElasticDeformation: class ElasticDeformation:
"""Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_. """Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_.
...@@ -68,7 +35,7 @@ class ElasticDeformation: ...@@ -68,7 +35,7 @@ class ElasticDeformation:
spline_order=1, spline_order=1,
mode="nearest", mode="nearest",
random_state=numpy.random, random_state=numpy.random,
p=1, p=1.0,
): ):
self.alpha = alpha self.alpha = alpha
self.sigma = sigma self.sigma = sigma
...@@ -79,13 +46,15 @@ class ElasticDeformation: ...@@ -79,13 +46,15 @@ class ElasticDeformation:
def __call__(self, img): def __call__(self, img):
if random.random() < self.p: if random.random() < self.p:
img = transforms.ToPILImage()(img) assert img.ndim == 3
# Input tensor is of shape C x H x W
# If the tensor only contains one channel, this conversion results in H x W.
# With 3 channels, we get H x W x C
img = transforms.ToPILImage()(img)
img = numpy.asarray(img) img = numpy.asarray(img)
assert img.ndim == 2 shape = img.shape[:2]
shape = img.shape
dx = ( dx = (
gaussian_filter( gaussian_filter(
...@@ -114,9 +83,22 @@ class ElasticDeformation: ...@@ -114,9 +83,22 @@ class ElasticDeformation:
numpy.reshape(y + dy, (-1, 1)), numpy.reshape(y + dy, (-1, 1)),
] ]
result = numpy.empty_like(img) result = numpy.empty_like(img)
result[:, :] = map_coordinates(
img[:, :], indices, order=self.spline_order, mode=self.mode if img.ndim == 2:
).reshape(shape) result[:, :] = map_coordinates(
img[:, :], indices, order=self.spline_order, mode=self.mode
).reshape(shape)
else:
for i in range(img.shape[2]):
result[:, :, i] = map_coordinates(
img[:, :, i],
indices,
order=self.spline_order,
mode=self.mode,
).reshape(shape)
return transforms.ToTensor()(PIL.Image.fromarray(result)) return transforms.ToTensor()(PIL.Image.fromarray(result))
else: else:
return img return img
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Defines most common types used in code."""
import typing
import torch
import torch.utils.data
Sample = tuple[torch.Tensor, typing.Mapping[str, typing.Any]]
"""Definition of a sample.
First parameter
The actual data that is input to the model
Second parameter
A dictionary containing a named set of meta-data. One the most common is
the ``label`` entry.
"""
class RawDataLoader:
"""A loader object can load samples and labels from storage."""
def sample(self, _: typing.Any) -> Sample:
"""Loads whole samples from media."""
raise NotImplementedError("You must implement the `sample()` method")
def label(self, k: typing.Any) -> int:
"""Loads only sample label from media.
If you do not override this implementation, then, by default,
this method will call :py:meth:`sample` to load the whole sample
and extract the label.
"""
return self.sample(k)[1]["label"]
Transform = typing.Callable[[torch.Tensor], torch.Tensor]
"""A callable, that transforms tensors into (other) tensors.
Typically used in data-processing pipelines inside pytorch.
"""
TransformSequence = typing.Sequence[Transform]
"""A sequence of transforms."""
DatabaseSplit = dict[str, typing.Sequence[typing.Any]]
"""The definition of a database script.
A database script maps subset names to sequences of objects that,
through RawDataLoader's eventually become Samples in the processing
pipeline.
"""
class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized):
"""Our own definition of a pytorch Dataset, with interesting properties.
We iterate over Sample objects in this case. Our datasets always
provide a dunder len method.
"""
def labels(self) -> list[int]:
"""Returns the integer labels for all samples in the dataset."""
raise NotImplementedError("You must implement the `labels()` method")
DataLoader = torch.utils.data.DataLoader[Sample]
"""Our own augmentation definition of a pytorch DataLoader.
We iterate over Sample objects in this case.
"""
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment