Commit 0827ccca authored by André Anjos

[ptbench.data.datamodule] Implemented typing, added more logging, implemented sampler-balancing depending on sample class prevalence, parallelized dataset caching
parent a69868bf
Merge request !7: Reviewed DataModule design+docs+types
This commit is part of merge request !6.
......@@ -12,8 +12,16 @@ import lightning
import torch
import torch.utils.data
import torchvision.transforms
import tqdm
from tqdm import tqdm
from .typing import (
DatabaseSplit,
DataLoader,
Dataset,
RawDataLoader,
Sample,
Transform,
)
logger = logging.getLogger(__name__)
......@@ -64,7 +72,7 @@ def _setup_dataloader_multiproc_parameters(
return retval
class _DelayedLoadingDataset(torch.utils.data.Dataset):
class _DelayedLoadingDataset(Dataset):
"""A list that loads its samples on demand.
This list mimics a pytorch Dataset, except raw data loading is done
......@@ -78,11 +86,8 @@ class _DelayedLoadingDataset(torch.utils.data.Dataset):
An iterable containing the raw dataset samples loaded from the database
splits.
raw_data_loader
A callable that will take the representation of the samples declared
inside database splits, load the actual raw data (if one exists), and
transform it into a :py:class:`torch.Tensor`containing floats in the
interval ``[0, 1]``, and a dictionary with further metadata attributes.
loader
An object instance that can load samples and labels from storage.
transforms
A set of transforms that should be applied on-the-fly for this dataset,
......@@ -92,30 +97,30 @@ class _DelayedLoadingDataset(torch.utils.data.Dataset):
def __init__(
self,
split: typing.Sequence[typing.Any],
raw_data_loader: typing.Callable[
[typing.Any], tuple[torch.Tensor, typing.Mapping]
],
transforms: typing.Sequence[
typing.Callable[[torch.Tensor], torch.Tensor]
] = [],
loader: RawDataLoader,
transforms: typing.Sequence[Transform] = [],
):
self.split = split
self.raw_data_loader = raw_data_loader
# Cannot unpack empty list
if len(transforms) > 0:
self.transform = torchvision.transforms.Compose([*transforms])
else:
self.transform = torchvision.transforms.Compose([])
self.loader = loader
self.transform = torchvision.transforms.Compose(transforms)
def labels(self) -> list[int]:
"""Returns the integer labels for all samples in the dataset."""
return [self.loader.label(k) for k in self.split]
def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
tensor, metadata = self.raw_data_loader(self.split[key])
def __getitem__(self, key: int) -> Sample:
tensor, metadata = self.loader.sample(self.split[key])
return self.transform(tensor), metadata
def __len__(self):
return len(self.split)
def __iter__(self):
for x in range(len(self)):
yield self[x]
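As an aside (not part of this commit), a minimal sketch of how such a delayed-loading dataset behaves: nothing is read from storage until an item is indexed. The ``DummyLoader`` class and the split entries below are hypothetical.

import torch

class DummyLoader:
    """Hypothetical loader turning (name, label) entries into tensors."""

    def sample(self, item):
        name, label = item
        # a real loader would read `name` from disk; here we fabricate data
        return torch.zeros(1, 8, 8), dict(label=label, name=name)

    def label(self, item):
        return item[1]

split = [("img-001", 0), ("img-002", 1)]
dataset = _DelayedLoadingDataset(split, DummyLoader())
tensor, metadata = dataset[0]       # raw loading happens here, on demand
print(tensor.shape, metadata["label"], dataset.labels())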
class _CachedDataset(torch.utils.data.Dataset):
class _CachedDataset(Dataset):
"""Basically, a list of preloaded samples.
This dataset will load all samples from the split during construction
......@@ -130,11 +135,14 @@ class _CachedDataset(torch.utils.data.Dataset):
An iterable containing the raw dataset samples loaded from the database
splits.
raw_data_loader
A callable that will take the representation of the samples declared
inside database splits, load the actual raw data (if one exists), and
transform it into a :py:class:`torch.Tensor`containing floats in the
interval ``[0, 1]``, and a dictionary with further metadata attributes.
loader
An object instance that can load samples and labels from storage.
parallel
Use multiprocessing for data loading: if set to -1 (default), disables
multiprocessing data loading. Set to 0 to enable as many data loading
instances as there are processing cores available in the system. Set to >= 1
to enable that many multiprocessing instances for data loading.
transforms
A set of transforms that should be applied to the cached samples for
......@@ -145,43 +153,96 @@ class _CachedDataset(torch.utils.data.Dataset):
def __init__(
self,
split: typing.Sequence[typing.Any],
raw_data_loader: typing.Callable[
[typing.Any], tuple[torch.Tensor, typing.Mapping]
],
transforms: typing.Sequence[
typing.Callable[[torch.Tensor], torch.Tensor]
] = [],
loader: RawDataLoader,
parallel: int = -1,
transforms: typing.Sequence[Transform] = [],
):
# Cannot unpack empty list
if len(transforms) > 0:
self.transform = torchvision.transforms.Compose([*transforms])
else:
self.transform = torchvision.transforms.Compose([])
self.data = [raw_data_loader(k) for k in tqdm(split)]
self.transform = torchvision.transforms.Compose(transforms)
def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
if parallel < 0:
self.data = [
loader.sample(k) for k in tqdm.tqdm(split, unit="sample")
]
else:
instances = parallel or multiprocessing.cpu_count()
logger.info(f"Caching dataset using {instances} processes...")
with multiprocessing.Pool(instances) as p:
self.data = list(
tqdm.tqdm(p.imap(loader.sample, split), total=len(split))
)
def labels(self) -> list[int]:
"""Returns the integer labels for all samples in the dataset."""
return [k[1]["label"] for k in self.data]
def __getitem__(self, key: int) -> Sample:
tensor, metadata = self.data[key]
return self.transform(tensor), metadata
def __len__(self):
return len(self.data)
def __iter__(self):
for x in range(len(self)):
yield self[x]
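The parallel branch above follows the standard ``multiprocessing.Pool``/``imap`` preloading pattern, sketched below in isolation (illustrative only; ``load_one`` is a hypothetical, picklable stand-in for ``loader.sample``). One design consequence is that the loader object and the split entries must be picklable.

import multiprocessing

import tqdm

def load_one(item):
    """Hypothetical, picklable stand-in for ``RawDataLoader.sample``."""
    name, label = item
    return name.upper(), dict(label=label)

if __name__ == "__main__":
    split = [(f"img-{k:03d}", k % 2) for k in range(100)]
    parallel = 0                                  # 0 -> use all available cores
    instances = parallel or multiprocessing.cpu_count()
    with multiprocessing.Pool(instances) as p:
        # imap preserves the input order and yields lazily, so tqdm can track it
        data = list(tqdm.tqdm(p.imap(load_one, split), total=len(split)))
    print(len(data), data[0])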
def get_sample_weights(
dataset: _DelayedLoadingDataset | _CachedDataset,
) -> torch.Tensor:
"""Computes the (inverse) probabilities of samples based on their class.
def _make_balanced_random_sampler(
dataset: Dataset,
target: str = "label",
) -> torch.utils.data.WeightedRandomSampler:
"""Generates a pytorch sampler that samples according to class
probabilities.
This function takes as input a torch Dataset, and computes the weights to
balance each class in the dataset, and the datasets themselves if one
passes a :py:class:`torch.utils.data.ConcatDataset`.
This function assumes labels are stored on the second entry of sample and
are integers. If that is not the case, we just balance the overall dataset
w.r.t. each other (e.g. with a :py:class:`torch.utils.data.ConcatDataset`).
For single dataset (not a concatenated one) without labels this function is
the same as a no-op.
In this implementation, we balance **both** class and dataset-origin
probabilities, which is what you expect from a truly *equitable* random sampler.
Take this example for illustration:
* Dataset 1: N = 10 samples, 9 samples with target=0, 1 sample with target=1
* Dataset 2: N = 6 samples, 3 samples with target=0, 3 samples with target=1
So:
| Dataset | Target | Samples | Weight | Normalised weight |
+---------+--------+---------+--------+-------------------+
| 1 | 0 | 9 | 1/9 | 1/36 |
| 1 | 1 | 1 | 1/1 | 1/4 |
| 2 | 0 | 3 | 1/3 | 1/12 |
| 2 | 1 | 3 | 1/3 | 1/12 |
Legend:
* Weight: the weights computed by this method
* Normalised weight: the weight per sample used by the random sampler,
after normalising the weights by the sum of all weights in the
concatenated dataset, such that the sum of the normalised weights over
all samples is 1.
The properties of this algorithm are as follows:
1. The probability of picking a sample from any target is the same (0.5 in
this case). To verify this, notice that the probability of picking a
sample with ``target=0`` is :math:`9 \times 1/36 + 3 \times 1/12 = 0.5`.
2. The probability of picking a sample with ``target=0`` from Dataset 2 is
3 times higher than those from Dataset 1. As there are 3 times less
samples in Dataset 2 with ``target=0``, this makes choosing samples from
Dataset 1 proportionally less likely.
3. The probability of picking a sample with ``target=1`` from Dataset 2 is
3 times lower than those from Dataset 1. As there are 3 times less
samples in Dataset 1 with ``target=1``, this makes choosing samples from
Dataset 2 proportionally less likely.
This function assumes targets are stored in the sample metadata dictionary
(the second entry of each ``Sample``), under the key named by the ``target``
parameter, and that their values are integers.
We then instantiate a pytorch sampler using the inverse probabilities (the
more samples of a class, the less likely it becomes to be sampled).
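For illustration (not part of this commit), a small self-contained check that reproduces the weights from the table above and the 0.5 per-target probability:

import collections

import torch

# the example above: dataset 1 has 9+1 samples, dataset 2 has 3+3 samples
targets = [0] * 9 + [1] * 1 + [0] * 3 + [1] * 3
origins = [1] * 10 + [2] * 6

counts = collections.Counter(zip(origins, targets))
weights = torch.tensor([1.0 / counts[key] for key in zip(origins, targets)])
normalised = weights / weights.sum()                 # sums to 1 over all samples

print(normalised[0].item())                          # 1/36 (dataset 1, target 0)
print(normalised[9].item())                          # 1/4  (dataset 1, target 1)
print(normalised[torch.tensor(targets) == 0].sum())  # 0.5, as claimed in item 1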
Parameters
......@@ -189,42 +250,83 @@ def get_sample_weights(
dataset
An instance of torch Dataset.
:py:class:`torch.utils.data.ConcatDataset` are supported
:py:class:`torch.utils.data.ConcatDataset` are supported.
target
The name of a metadata key pointing to an integer property that allows
balancing the dataset.
Returns
-------
sample_weights
The weights for all the samples in the dataset given as input
sampler
A sampler, to be used in a dataloader equipped with the same dataset
used to calculate the relative sample weights.
Raises
------
RuntimeError
If requested to balance a dataset (single, not-concatenated) without an
existing target.
"""
retval = []
def _calculate_dataset_weights(
d: _DelayedLoadingDataset | _CachedDataset,
) -> torch.Tensor:
"""Calculates the weights for one dataset."""
if "label" in d[0][1] and isinstance(d[0][1]["label"], int):
# there are labels
targets = [k[1]["label"] for k in d]
counts = collections.Counter(targets)
weights = {k: 1.0 / v for k, v in counts.items()}
weight_per_sample = [weights[k[1]] for k in d]
return torch.tensor(weight_per_sample)
else:
# no labels, weight dataset samples only by count
# n.b.: only works if __len__(d) is implemented, we turn off typechecking
weight = 1.0 / len(d)
return torch.tensor(len(d) * [weight])
def _calculate_weights(targets: list[int]) -> list[float]:
counts = collections.Counter(targets)
weights = {k: 1.0 / v for k, v in counts.items()}
return [weights[k] for k in targets]
if isinstance(dataset, torch.utils.data.ConcatDataset):
for ds in dataset.datasets:
retval.append(_calculate_dataset_weights(ds)) # type: ignore[arg-type]
# There are two possible cases: targets/no-targets
metadata_example = dataset.datasets[0][0][1]
if target in metadata_example and isinstance(
metadata_example[target], int
):
# there are integer targets, let's balance with those
logger.info(
f"Balancing sample selection probabilities **and** "
f"concatenated-datasets using metadata targets `{target}`"
)
targets = [
k
for ds in dataset.datasets
for k in typing.cast(Dataset, ds).labels()
]
weights = _calculate_weights(targets)
else:
logger.warning(
f"Balancing samples **and** concatenated-datasets "
f"WITHOUT metadata targets (`{target}` not available)"
)
weights = [
k
for ds in dataset.datasets
for k in len(typing.cast(typing.Sized, ds))
* [1.0 / len(typing.cast(typing.Sized, ds))]
]
# Concatenate sample weights from all the datasets
return torch.cat(retval)
pass
return _calculate_dataset_weights(dataset)
else:
metadata_example = dataset[0][1]
if target in metadata_example and isinstance(
metadata_example[target], int
):
logger.info(
f"Balancing samples from dataset using metadata "
f"targets `{target}`"
)
weights = _calculate_weights(dataset.labels())
else:
raise RuntimeError(
f"Cannot balance samples without metadata targets `{target}`"
)
return torch.utils.data.WeightedRandomSampler(
weights, len(weights), replacement=True
)
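A minimal, hypothetical usage sketch of the function above; the tiny dataset stands in for a real one, and, as in ``train_dataloader()`` further down, the sampler replaces pytorch's default shuffling:

import torch
import torch.utils.data

class TinyDataset(torch.utils.data.Dataset):
    """Hypothetical dataset: 9 samples of class 0 and 1 sample of class 1."""

    def __init__(self):
        self._labels = [0] * 9 + [1]

    def labels(self):
        return list(self._labels)

    def __getitem__(self, key):
        return torch.zeros(3), dict(label=self._labels[key])

    def __len__(self):
        return len(self._labels)

sampler = _make_balanced_random_sampler(TinyDataset())
loader = torch.utils.data.DataLoader(
    TinyDataset(), batch_size=4, sampler=sampler  # no shuffle with a sampler
)
tensors, metadata = next(iter(loader))
print(metadata["label"])  # class 1 shows up roughly as often as class 0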
class CachingDataModule(lightning.LightningDataModule):
......@@ -253,10 +355,10 @@ class CachingDataModule(lightning.LightningDataModule):
database_split
A dictionary that contains string keys representing subset names, and
values that are iterables over sample representations (potentially on
disk). These objects are passed to the ``raw_data_loader`` for loading
disk). These objects are passed to the ``sample_loader`` for loading
the sample data (and metadata) in memory. The objects represented may
be of any format (e.g. list, dictionary, etc), for as long as the
``raw_data_loader`` can properly handle it. To check the split and the
``sample_loader`` can properly handle it. To check the split and the
loader function works correctly, you may use
:py:func:`..dataset.check_database_split_loading`. As is, this class
expects at least one entry called ``train`` to exist in the input
......@@ -265,12 +367,8 @@ class CachingDataModule(lightning.LightningDataModule):
influence any early stop criteria during training, and are just
monitored beyond the ``validation`` dataset.
raw_data_loader
A callable that will take the representation of the samples declared
inside database splits, load the actual raw data (if one exists), and
transform it into a :py:class:`torch.Tensor`containing floats in the
interval ``[0, 1]``, and a dictionary with further metadata attributes.
Samples can be cached **after** raw data loading.
loader
An object instance that can load samples and labels from storage.
cache_samples
If set, then issue raw data loading during ``prepare_data()``, and
......@@ -280,25 +378,10 @@ class CachingDataModule(lightning.LightningDataModule):
this attribute to ``True``. It is typically useful for relatively small
datasets.
train_sampler
If set, then reset the default random sampler from torch with a
(potentially) biased one of your choice. Notice this will not play an
effect on the batching strategy, only on the way the samples are picked
from the original dataset. A typical replacement is:
.. code:: python
sample_weights = get_sample_weights(dataset)
train_sampler = torch.utils.data.WeightedRandomSampler(
sample_weights, len(sample_weights), replacement=True
)
The variable ``sample_weights`` represents the probabilities (may not
sum to one) of picking each particular sample in the dataset for a
batch. We typically set ``replacement=True`` to avoid issues with
missing data from one of the classes in mini-batches. This is
particularly important in highly unbalanced datasets. The function
:py:func:`get_sample_weights` may help you in this aspect.
balance_sampler_by_class
If set, then modifies the random sampler used during training and
validation to balance sample picking probability, making sample
picking across classes **and** datasets equitable.
model_transforms
A list of transforms (torch modules) that will be applied after
......@@ -346,17 +429,15 @@ class CachingDataModule(lightning.LightningDataModule):
to enable that many multiprocessing instances for data loading.
"""
DatasetDictionary = dict[str, Dataset]
def __init__(
self,
database_split: dict[str, typing.Sequence[typing.Any]],
raw_data_loader: typing.Callable[
[typing.Any], tuple[torch.Tensor, typing.Mapping]
],
database_split: DatabaseSplit,
raw_data_loader: RawDataLoader,
cache_samples: bool = False,
train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
model_transforms: list[
typing.Callable[[torch.Tensor], torch.Tensor]
] = [],
balance_sampler_by_class: bool = False,
model_transforms: list[Transform] = [],
batch_size: int = 1,
batch_chunk_count: int = 1,
drop_incomplete_batch: bool = False,
......@@ -369,7 +450,8 @@ class CachingDataModule(lightning.LightningDataModule):
self.database_split = database_split
self.raw_data_loader = raw_data_loader
self.cache_samples = cache_samples
self.train_sampler = train_sampler
self._train_sampler = None
self.balance_sampler_by_class = balance_sampler_by_class
self.model_transforms = model_transforms
self.drop_incomplete_batch = drop_incomplete_batch
......@@ -380,11 +462,18 @@ class CachingDataModule(lightning.LightningDataModule):
) # should only be true if GPU available and using it
# datasets that have been setup() for the current stage
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef]
self._datasets: CachingDataModule.DatasetDictionary = {}
@property
def parallel(self) -> int:
"""The parallel property."""
"""Whether to use multiprocessing for data loading.
Use multiprocessing for data loading: if set to -1 (default),
disables multiprocessing data loading. Set to 0 to enable as
many data loading instances as there are processing cores available
in the system. Set to >= 1 to enable that many multiprocessing
instances for data loading.
"""
return self._parallel
@parallel.setter
......@@ -394,7 +483,28 @@ class CachingDataModule(lightning.LightningDataModule):
value
)
# datasets that have been setup() for the current stage
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef]
self._datasets = {}
@property
def balance_sampler_by_class(self):
"""Whether to balance samples across labels/datasets.
If set, then modifies the random sampler used during training
and validation to balance sample picking probability, making
sample picking across classes **and** datasets equitable.
"""
return self._train_sampler is not None
@balance_sampler_by_class.setter
def balance_sampler_by_class(self, value: bool):
if value:
if "train" not in self._datasets:
self._setup_dataset("train")
self._train_sampler = _make_balanced_random_sampler(
self._datasets["train"]
)
else:
self._train_sampler = None
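At the user level, this property means balancing can be toggled on an existing datamodule; an illustrative sketch (the tiny split and loader below are hypothetical, and default construction options are assumed):

import torch

class TinyLoader:
    """Hypothetical loader fabricating tensors for (name, label) entries."""

    def sample(self, item):
        return torch.zeros(1, 4, 4), dict(label=item[1])

    def label(self, item):
        return item[1]

split = {
    "train": [("a", 0), ("b", 0), ("c", 1)],
    "validation": [("d", 0)],
    "test": [("e", 1)],
}

datamodule = CachingDataModule(database_split=split, raw_data_loader=TinyLoader())
datamodule.balance_sampler_by_class = True   # sets up `train` and installs the sampler
datamodule.setup("fit")                      # `train` is found ready and not rebuilt
batch = next(iter(datamodule.train_dataloader()))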
def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
"""Coherently sets the batch-chunk-size after validation.
......@@ -450,22 +560,39 @@ class CachingDataModule(lightning.LightningDataModule):
"""
if name in self._datasets:
logger.info(f"Dataset {name} is already setup. Not reloading it.")
logger.info(
f"Dataset `{name}` is already setup. "
f"Not re-instantiating it."
)
return
if self.cache_samples:
logger.info(f"Caching {name} dataset")
logger.info(
f"Loading dataset:`{name}` into memory (caching)."
f" Trade-off: CPU RAM: more | Disk: less"
)
self._datasets[name] = _CachedDataset(
self.database_split[name],
self.raw_data_loader,
self.parallel,
self.model_transforms,
)
else:
logger.info(
f"Loading dataset:`{name}` without caching."
f" Trade-off: CPU RAM: less | Disk: more"
)
self._datasets[name] = _DelayedLoadingDataset(
self.database_split[name],
self.raw_data_loader,
self.model_transforms,
)
def _val_dataset_keys(self) -> list[str]:
"""Returns list of validation dataset names."""
return ["validation"] + [
k for k in self.database_split.keys() if k.startswith("monitor-")
]
def setup(self, stage: str) -> None:
"""Sets up datasets for different tasks on the pipeline.
......@@ -493,17 +620,12 @@ class CachingDataModule(lightning.LightningDataModule):
"""
if stage == "fit":
self._setup_dataset("train")
self._setup_dataset("validation")
for k in self.database_split:
if k.startswith("monitor-"):
self._setup_dataset(k)
for k in ["train"] + self._val_dataset_keys():
self._setup_dataset(k)
elif stage == "validate":
self._setup_dataset("validation")
for k in self.database_split:
if k.startswith("monitor-"):
self._setup_dataset(k)
for k in self._val_dataset_keys():
self._setup_dataset(k)
elif stage == "test":
self._setup_dataset("test")
......@@ -535,32 +657,33 @@ class CachingDataModule(lightning.LightningDataModule):
* ``predict``: uses only the test dataset
"""
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef]
self._datasets = {}
def train_dataloader(self) -> torch.utils.data.DataLoader:
def train_dataloader(self) -> DataLoader:
"""Returns the train data loader."""
return torch.utils.data.DataLoader(
self._datasets["train"],
shuffle=True,
shuffle=(self._train_sampler is None),
batch_size=self._chunk_size,
drop_last=self.drop_incomplete_batch,
pin_memory=self.pin_memory,
sampler=self.train_sampler,
sampler=self._train_sampler,
**self._dataloader_multiproc,
)
def available_dataset_keys(self) -> typing.KeysView[str]:
"""Returns all names for datasets that are setup."""
return self._datasets.keys()
def unshuffled_train_dataloader(self) -> DataLoader:
"""Returns the train data loader without shuffling."""
def val_database_split_keys(self) -> list[str]:
"""Returns list of validation dataset names."""
return ["validation"] + [
k for k in self.database_split.keys() if k.startswith("monitor-")
]
return torch.utils.data.DataLoader(
self._datasets["train"],
shuffle=False,
batch_size=self._chunk_size,
drop_last=False,
**self._dataloader_multiproc,
)
def val_dataloader(self) -> dict[str, torch.utils.data.DataLoader]:
def val_dataloader(self) -> dict[str, DataLoader]:
"""Returns the validation data loader(s)"""
validation_loader_opts = {
......@@ -571,27 +694,28 @@ class CachingDataModule(lightning.LightningDataModule):
}
validation_loader_opts.update(self._dataloader_multiproc)
# select all keys of interest
return {
k: torch.utils.data.DataLoader(
self._datasets[k], **validation_loader_opts
)
for k in self.val_database_split_keys()
for k in self._val_dataset_keys()
}
def test_dataloader(self):
def test_dataloader(self) -> dict[str, DataLoader]:
"""Returns the test data loader(s)"""
return torch.utils.data.DataLoader(
self._datasets["test"],
batch_size=self._chunk_size,
shuffle=False,
drop_last=self.drop_incomplete_batch,
pin_memory=self.pin_memory,
**self._dataloader_multiproc,
return dict(
test=torch.utils.data.DataLoader(
self._datasets["test"],
batch_size=self._chunk_size,
shuffle=False,
drop_last=self.drop_incomplete_batch,
pin_memory=self.pin_memory,
**self._dataloader_multiproc,
)
)
def predict_dataloader(self):
def predict_dataloader(self) -> dict[str, DataLoader]:
"""Returns the prediction data loader(s)"""
return self.test_dataloader()
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import torch
import torch.utils.data
logger = logging.getLogger(__name__)
def _get_positive_weights(dataloader):
"""Compute the positive weights of each class of the dataset to balance the
BCEWithLogitsLoss criterion.
This function takes as input a :py:class:`torch.utils.data.DataLoader`
and computes the positive weights of each class to use them to have
a balanced loss.
Parameters
----------
dataloader : :py:class:`torch.utils.data.DataLoader`
A DataLoader from which to compute the positive weights. Must contain a 'label' key in the metadata returned by __getitem__().
Returns
-------
positive_weights : :py:class:`torch.Tensor`
the positive weight of each class in the dataset given as input
"""
targets = []
for batch in dataloader:
targets.extend(batch[1]["label"])
targets = torch.tensor(targets)
# Binary labels
if len(list(targets.shape)) == 1:
class_sample_count = [
float((targets == t).sum().item())
for t in torch.unique(targets, sorted=True)
]
# Divide negatives by positives
positive_weights = torch.tensor(
[class_sample_count[0] / class_sample_count[1]]
).reshape(-1)
# Multiclass labels
else:
class_sample_count = torch.sum(targets, dim=0)
negative_class_sample_count = (
torch.full((targets.size()[1],), float(targets.size()[0]))
- class_sample_count
)
positive_weights = negative_class_sample_count / (
class_sample_count + negative_class_sample_count
)
return positive_weights
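The returned weights are intended for the ``pos_weight`` argument of the loss; a minimal sketch of that idea with fabricated counts instead of a real dataloader:

import torch
import torch.nn

# fabricated binary targets: 9 negatives and 1 positive
targets = torch.tensor([0.0] * 9 + [1.0]).reshape(-1, 1)

negatives = float((targets == 0).sum())
positives = float((targets == 1).sum())
pos_weight = torch.tensor([negatives / positives])   # 9.0: up-weights positives

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
logits = torch.zeros_like(targets)                   # dummy model outputs
print(criterion(logits, targets))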
......@@ -5,6 +5,8 @@
"""Data loading code."""
import pathlib
import numpy
import PIL.Image
......@@ -41,58 +43,58 @@ class RemoveBlackBorders:
return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
def load_pil(path):
def load_pil(path: str | pathlib.Path) -> PIL.Image.Image:
"""Loads a sample data.
Parameters
----------
path : str
path
The full path leading to the image to be loaded
Returns
-------
image : PIL.Image.Image
image
A PIL image
"""
return PIL.Image.open(path)
def load_pil_baw(path):
def load_pil_baw(path: str | pathlib.Path) -> PIL.Image.Image:
"""Loads a sample data.
Parameters
----------
path : str
path
The full path leading to the image to be loaded
Returns
-------
image : PIL.Image.Image
image
A PIL image in grayscale mode
"""
return load_pil(path).convert("L")
def load_pil_rgb(path):
def load_pil_rgb(path: str | pathlib.Path) -> PIL.Image.Image:
"""Loads a sample data.
Parameters
----------
path : str
path
The full path leading to the image to be loaded
Returns
-------
image : PIL.Image.Image
image
A PIL image in RGB mode
"""
return load_pil(path).convert("RGB")
......@@ -4,19 +4,38 @@
"""Shenzhen datamodule for computer-aided diagnosis (default protocol)
See :py:mod:`ptbench.data.shenzhen` for dataset details.
See :py:mod:`ptbench.data.shenzhen` for more database details.
This configuration:
* raw data (default): :py:obj:`ptbench.data.shenzhen._tranforms`
* augmentations: elastic deformation (probability = 80%)
* output image resolution: 512x512 pixels
* Raw data input (on disk):
* PNG images (black and white, encoded as color images)
* Variable width and height:
* widths: from 1130 to 3001 pixels
* heights: from 948 to 3001 pixels
* Output image:
* Transforms:
* Load raw PNG with :py:mod:`PIL`
* Remove black borders
* Torch resizing (512 px, 512 px)
* Torch center cropping (512px, 512px)
* Final specifications:
* Fixed resolution: 512x512 pixels
* Color RGB encoding
"""
import importlib.resources
from ..datamodule import CachingDataModule
from ..split import JSONDatabaseSplit
from .raw_data_loader import raw_data_loader
from .loader import RawDataLoader
datamodule = CachingDataModule(
database_split=JSONDatabaseSplit(
......@@ -24,12 +43,5 @@ datamodule = CachingDataModule(
"default.json.bz2"
)
),
raw_data_loader=raw_data_loader,
cache_samples=False,
# train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
# model_transforms = [],
# batch_size = 1,
# batch_chunk_count = 1,
# drop_incomplete_batch = False,
# parallel = -1,
raw_data_loader=RawDataLoader(),
)
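A brief, hypothetical usage sketch for this configuration (it assumes ``datadir.shenzhen`` points at a local copy of the database):

# inspect what the configured datamodule serves for testing
datamodule.setup("test")
loaders = datamodule.test_dataloader()          # dict with a single "test" entry
batch = next(iter(loaders["test"]))             # (tensor, metadata) pair
print(list(loaders.keys()), batch[0].shape)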
......@@ -21,47 +21,85 @@ Philips DR Digital Diagnose systems.
"""
import os
import typing
import torch.nn
import torchvision.transforms
from ...utils.rc import load_rc
from ..raw_data_loader import RemoveBlackBorders, load_pil_baw
from ..image_utils import RemoveBlackBorders, load_pil_baw
from ..typing import RawDataLoader as _BaseRawDataLoader
from ..typing import Sample
_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
"""This variable contains the base directory where the database raw data is
stored."""
_transform = torchvision.transforms.Compose(
[
RemoveBlackBorders(),
torchvision.transforms.Resize(512),
torchvision.transforms.CenterCrop(512),
torchvision.transforms.ToTensor(),
]
)
"""Transforms that are always applied to the loaded raw images."""
class RawDataLoader(_BaseRawDataLoader):
"""A specialized raw-data-loader for the Shenzen dataset.
Attributes
----------
def raw_data_loader(
sample: tuple[str, int]
) -> tuple[torch.Tensor, typing.Mapping]:
"""Loads a single image sample from the disk.
datadir
This variable contains the base directory where the database raw data
is stored.
Parameters
----------
transform
Transforms that are always applied to the loaded raw images.
"""
img_path
The path suffix, within the dataset root folder, where to find the
image to be loaded.
datadir: str
transform: torchvision.transforms.Compose
def __init__(self):
self.datadir = load_rc().get(
"datadir.shenzhen", os.path.realpath(os.curdir)
)
Returns
-------
self.transform = torchvision.transforms.Compose(
[
RemoveBlackBorders(),
torchvision.transforms.Resize(512),
torchvision.transforms.CenterCrop(512),
torchvision.transforms.ToTensor(),
]
)
image
A PIL image in grayscale mode
"""
tensor = _transform(load_pil_baw(os.path.join(_datadir, sample[0])))
return tensor, dict(label=sample[1]) # type: ignore[arg-type]
def sample(self, sample: tuple[str, int]) -> Sample:
"""Loads a single image sample from the disk.
Parameters
----------
sample:
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing the
sample label.
Returns
-------
sample
The sample representation
"""
tensor = self.transform(
load_pil_baw(os.path.join(self.datadir, sample[0]))
)
return tensor, dict(label=sample[1]) # type: ignore[arg-type]
def label(self, sample: tuple[str, int]) -> int:
"""Loads a single image sample label from the disk.
Parameters
----------
sample:
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing the
sample label.
Returns
-------
label
The integer label associated with the sample
"""
return sample[1]
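For illustration, loading one sample versus only its label with this loader (the file name below is hypothetical and must exist under the configured ``datadir.shenzhen``):

loader = RawDataLoader()

# a split entry is a (path-suffix, integer-label) tuple
entry = ("CXR_png/CHNCXR_0001_0.png", 0)

tensor, metadata = loader.sample(entry)   # decode + crop + resize + tensorize
print(tensor.shape, metadata["label"])    # torch.Size([1, 512, 512]) 0

print(loader.label(entry))                # 0, without touching the image file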
......@@ -2,7 +2,6 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import collections.abc
import csv
import importlib.abc
import json
......@@ -12,13 +11,12 @@ import typing
import torch
from .typing import DatabaseSplit, RawDataLoader
logger = logging.getLogger(__name__)
class JSONDatabaseSplit(
dict,
typing.Mapping[str, typing.Sequence[typing.Any]],
):
class JSONDatabaseSplit(DatabaseSplit):
"""Defines a loader that understands a database split (train, test, etc) in
JSON format.
......@@ -73,7 +71,7 @@ class JSONDatabaseSplit(
self.path = path
self.subsets = self._load_split_from_disk()
def _load_split_from_disk(self) -> dict[str, typing.Sequence[typing.Any]]:
def _load_split_from_disk(self) -> DatabaseSplit:
"""Loads all subsets in a split from its file system representation.
This method will load JSON information for the current split and return
......@@ -109,7 +107,7 @@ class JSONDatabaseSplit(
return len(self.subsets)
class CSVDatabaseSplit(collections.abc.Mapping):
class CSVDatabaseSplit(DatabaseSplit):
"""Defines a loader that understands a database split (train, test, etc) in
CSV format.
......@@ -154,9 +152,7 @@ class CSVDatabaseSplit(collections.abc.Mapping):
self.directory = directory
self.subsets = self._load_split_from_disk()
def _load_split_from_disk(
self,
) -> dict[str, list[typing.Any]]:
def _load_split_from_disk(self) -> DatabaseSplit:
"""Loads all subsets in a split from its file system representation.
This method will load CSV information for the current split and return all
......@@ -171,7 +167,7 @@ class CSVDatabaseSplit(collections.abc.Mapping):
A dictionary mapping subset names to lists of JSON objects
"""
retval = {}
retval: DatabaseSplit = {}
for subset in self.directory.iterdir():
if str(subset).endswith(".csv.bz2"):
logger.debug(f"Loading database split from {subset}...")
......@@ -204,8 +200,8 @@ class CSVDatabaseSplit(collections.abc.Mapping):
def check_database_split_loading(
database_split: typing.Mapping[str, typing.Sequence[typing.Any]],
loader: typing.Callable[[typing.Any], torch.Tensor],
database_split: DatabaseSplit,
loader: RawDataLoader,
limit: int = 0,
) -> int:
"""For each subset in the split, check if all data can be correctly loaded
......@@ -224,9 +220,7 @@ def check_database_split_loading(
object that represents a single sample.
loader
A callable that transforms sample entries in the database split
into :py:class:`torch.Tensor` objects that can be used for training
or inference.
A loader object that knows how to handle full-samples or just labels.
limit
Maximum number of samples to check (in each split/subset
......@@ -248,7 +242,7 @@ def check_database_split_loading(
samples = subset if not limit else subset[:limit]
for pos, sample in enumerate(samples):
try:
data = loader(sample)
data, _ = loader.sample(sample)
assert isinstance(data, torch.Tensor)
except Exception as e:
logger.info(
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Defines most common types used in code."""
import typing
import torch
import torch.utils.data
Sample = tuple[torch.Tensor, typing.Mapping[str, typing.Any]]
"""Definition of a sample.
First parameter
The actual data that is input to the model
Second parameter
A dictionary containing a named set of meta-data. One of the most common is
the ``label`` entry.
"""
class RawDataLoader:
"""A loader object can load samples and labels from storage."""
def sample(self, _: typing.Any) -> Sample:
"""Loads whole samples from media."""
raise NotImplementedError("You must implement the `sample()` method")
def label(self, k: typing.Any) -> int:
"""Loads only sample label from media.
If you do not override this implementation, then, by default,
this method will call :py:meth:`sample` to load the whole sample
and extract the label.
"""
return self.sample(k)[1]["label"]
Transform = typing.Callable[[torch.Tensor], torch.Tensor]
"""A callable, that transforms tensors into (other) tensors.
Typically used in data-processing pipelines inside pytorch.
"""
TransformSequence = typing.Sequence[Transform]
"""A sequence of transforms."""
DatabaseSplit = dict[str, typing.Sequence[typing.Any]]
"""The definition of a database script.
A database script maps subset names to sequences of objects that,
through RawDataLoader's eventually become Samples in the processing
pipeline.
"""
class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized):
"""Our own definition of a pytorch Dataset, with interesting properties.
We iterate over Sample objects in this case. Our datasets always
provide a dunder len method.
"""
def labels(self) -> list[int]:
"""Returns the integer labels for all samples in the dataset."""
raise NotImplementedError("You must implement the `labels()` method")
DataLoader = torch.utils.data.DataLoader[Sample]
"""Our own augmentation definition of a pytorch DataLoader.
We iterate over Sample objects in this case.
"""