Skip to content
Snippets Groups Projects
Commit 4a2375d7 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[dataloader] Rename label to target

parent 800081d1
No related branches found
No related tags found
1 merge request!46Create common library
Showing
with 301 additions and 83 deletions
......@@ -51,7 +51,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing
the sample label.
the sample target.
Returns
-------
......@@ -67,22 +67,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
# to_pil_image(tensor).show()
# __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type]
return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type]
def label(self, sample: tuple[str, int]) -> int:
"""Load a single image sample label from the disk.
def target(self, sample: tuple[str, int]) -> int:
"""Load a single image sample target from the disk.
Parameters
----------
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing
the sample label.
the sample target.
Returns
-------
int
The integer label associated with the sample.
The integer target associated with the sample.
"""
return sample[1]
......
......@@ -57,7 +57,7 @@ class RawDataLoader(_BaseRawDataLoader):
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing
the sample label.
the sample target.
Returns
-------
......@@ -75,22 +75,22 @@ class RawDataLoader(_BaseRawDataLoader):
# to_pil_image(tensor).show()
# __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type]
return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type]
def label(self, sample: tuple[str, int]) -> int:
"""Load a single image sample label from the disk.
def target(self, sample: tuple[str, int]) -> int:
"""Load a single image sample target from the disk.
Parameters
----------
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing
the sample label.
the sample target.
Returns
-------
int
The integer label associated with the sample.
The integer target associated with the sample.
"""
return sample[1]
......
......@@ -74,7 +74,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing
the sample label.
the sample target.
Returns
-------
......@@ -99,22 +99,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
# to_pil_image(tensor).show()
# __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type]
return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type]
def label(self, sample: tuple[str, list[int]]) -> list[int]:
"""Load a single image sample label from the disk.
def target(self, sample: tuple[str, list[int]]) -> list[int]:
"""Load a single image sample target from the disk.
Parameters
----------
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing the
sample label.
sample target.
Returns
-------
list[int]
The integer labels associated with the sample.
The integer targets associated with the sample.
"""
return sample[1]
......
......@@ -52,7 +52,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing
the sample label.
the sample target.
Returns
-------
......@@ -70,22 +70,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
# to_pil_image(tensor).show()
# __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type]
return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type]
def label(self, sample: tuple[str, int | list[int]]) -> int | list[int]:
"""Load a single image sample label from the disk.
def target(self, sample: tuple[str, int | list[int]]) -> int | list[int]:
"""Load a single image sample target from the disk.
Parameters
----------
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing
the sample label.
the sample target.
Returns
-------
list[int]
The integer labels associated with the sample.
The integer targets associated with the sample.
"""
return sample[1]
......
......@@ -57,7 +57,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing
the sample label.
the sample target.
Returns
-------
......@@ -75,22 +75,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
# to_pil_image(tensor).show()
# __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type]
return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type]
def label(self, sample: tuple[str, int]) -> int:
"""Load a single image sample label from the disk.
def target(self, sample: tuple[str, int]) -> int:
"""Load a single image sample target from the disk.
Parameters
----------
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing
the sample label.
the sample target.
Returns
-------
int
The integer label associated with the sample.
The integer target associated with the sample.
"""
return sample[1]
......
......@@ -47,7 +47,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing
the sample label.
the sample target.
Returns
-------
......@@ -65,21 +65,21 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
# to_pil_image(tensor).show()
# __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type]
return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type]
def label(self, sample: tuple[str, int]) -> int:
"""Load a single image sample label from the disk.
def target(self, sample: tuple[str, int]) -> int:
"""Load a single image sample target from the disk.
Parameters
----------
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing
the sample label.
the sample target.
Returns
-------
The integer label associated with the sample
The integer target associated with the sample
"""
return sample[1]
......
......@@ -199,7 +199,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, an integer, representing the
sample label, and possible radiological findings represented by
sample target, and possible radiological findings represented by
bounding boxes.
Returns
......@@ -222,26 +222,26 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
)
return tensor, dict(
label=sample[1],
name=sample[0],
target=sample[1],
bounding_boxes=self.bounding_boxes(sample),
)
def label(self, sample: DatabaseSample) -> int:
"""Load a single image sample label from the disk.
def target(self, sample: DatabaseSample) -> int:
"""Load a single image sample target from the disk.
Parameters
----------
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, an integer, representing the
sample label, and possible radiological findings represented by
sample target, and possible radiological findings represented by
bounding boxes.
Returns
-------
int
The integer label associated with the sample.
The integer target associated with the sample.
"""
return sample[1]
......@@ -254,7 +254,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, an integer, representing the
sample label, and possible radiological findings represented by
sample target, and possible radiological findings represented by
bounding boxes.
Returns
......
......@@ -13,6 +13,7 @@ import torchvision.models as models
import torchvision.transforms
from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model
from .separate import separate
from .transforms import RGB, SquareCenterPad
......@@ -129,7 +130,7 @@ class Alexnet(Model):
def training_step(self, batch, _):
images = batch[0]
labels = batch[1]["label"]
labels = batch[1]["target"]
# Increase label dimension if too low
# Allows single and multiclass usage
......@@ -143,7 +144,7 @@ class Alexnet(Model):
def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[0]
labels = batch[1]["label"]
labels = batch[1]["target"]
# Increase label dimension if too low
# Allows single and multiclass usage
......
......@@ -13,6 +13,7 @@ import torchvision.models as models
import torchvision.transforms
from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model
from .separate import separate
from .transforms import RGB, SquareCenterPad
......@@ -136,7 +137,7 @@ class Densenet(Model):
def training_step(self, batch, _):
images = batch[0]
labels = batch[1]["label"]
labels = batch[1]["target"]
# Increase label dimension if too low
# Allows single and multiclass usage
......@@ -150,7 +151,7 @@ class Densenet(Model):
def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[0]
labels = batch[1]["label"]
labels = batch[1]["target"]
# Increase label dimension if too low
# Allows single and multiclass usage
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import torch
import torch.utils.data
from mednet.libs.common.data.typing import DataLoader
logger = logging.getLogger("mednet")
def _get_label_weights(
dataloader: torch.utils.data.DataLoader,
) -> torch.Tensor:
"""Compute the weights of each class of a DataLoader.
This function inputs a pytorch DataLoader and computes the ratio between
number of negative and positive samples (scalar). The weight can be used
to adjust minimisation criteria to in cases there is a huge data imbalance.
It returns a vector with weights (inverse counts) for each label.
Parameters
----------
dataloader
A DataLoader from which to compute the positive weights. Entries must
be a dictionary which must contain a ``label`` key.
Returns
-------
torch.Tensor
The positive weight of each class in the dataset given as input.
"""
targets = torch.tensor(
[sample for batch in dataloader for sample in batch[1]["target"]],
)
# Binary labels
if len(list(targets.shape)) == 1:
class_sample_count = [
float((targets == t).sum().item())
for t in torch.unique(targets, sorted=True)
]
# Divide negatives by positives
positive_weights = torch.tensor(
[class_sample_count[0] / class_sample_count[1]],
).reshape(-1)
# Multiclass labels
else:
class_sample_count = torch.sum(targets, dim=0)
negative_class_sample_count = (
torch.full((targets.size()[1],), float(targets.size()[0]))
- class_sample_count
)
positive_weights = negative_class_sample_count / (
class_sample_count + negative_class_sample_count
)
return positive_weights
def make_balanced_bcewithlogitsloss(
dataloader: DataLoader,
) -> torch.nn.BCEWithLogitsLoss:
"""Return a balanced binary-cross-entropy loss.
The loss is weighted using the ratio between positives and total examples
available.
Parameters
----------
dataloader
The DataLoader to use to compute the BCE weights.
Returns
-------
torch.nn.BCEWithLogitsLoss
An instance of the weighted loss.
"""
weights = _get_label_weights(dataloader)
return torch.nn.BCEWithLogitsLoss(pos_weight=weights)
......@@ -13,6 +13,7 @@ import torch.utils.data
import torchvision.transforms
from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model
from .separate import separate
from .transforms import Grayscale, SquareCenterPad
......@@ -202,7 +203,7 @@ class Pasa(Model):
def training_step(self, batch, _):
images = batch[0]
labels = batch[1]["label"]
labels = batch[1]["target"]
# Increase label dimension if too low
# Allows single and multiclass usage
......@@ -216,7 +217,7 @@ class Pasa(Model):
def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[0]
labels = batch[1]["label"]
labels = batch[1]["target"]
# Increase label dimension if too low
# Allows single and multiclass usage
......
......@@ -28,7 +28,7 @@ def _as_predictions(
A list of typed predictions that can be saved to disk.
"""
return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples]
return [(v[1]["name"], v[1]["target"].item(), v[0].item()) for v in samples]
def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]:
......
......@@ -112,7 +112,7 @@ class _DelayedLoadingDataset(Dataset):
An iterable containing the raw dataset samples representing one of the
database split datasets.
loader
An object instance that can load samples and labels from storage.
An object instance that can load samples from storage.
transforms
A set of transforms that should be applied on-the-fly for this dataset,
to fit the output of the raw-data-loader to the model of interest.
......@@ -137,17 +137,6 @@ class _DelayedLoadingDataset(Dataset):
sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
def labels(self) -> list[int | list[int]]:
"""Return the integer labels for all samples in the dataset.
Returns
-------
list[int | list[int]]
The integer labels for all samples in the dataset.
"""
return [self.loader.label(k) for k in self.raw_dataset]
def __getitem__(self, key: int) -> Sample:
tensor, metadata = self.loader.sample(self.raw_dataset[key])
return self.transform(tensor), transform_tvtensors(
......@@ -206,7 +195,7 @@ class _CachedDataset(Dataset):
An iterable containing the raw dataset samples representing one of the
database split datasets.
loader
An object instance that can load samples and labels from storage.
An object instance that can load samples and targets from storage.
parallel
Use multiprocessing for data loading: if set to -1 (default), disables
multiprocessing data loading. Set to 0 to enable as many data loading
......@@ -255,16 +244,16 @@ class _CachedDataset(Dataset):
f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb",
)
def labels(self) -> list[int | list[int]]:
"""Return the integer labels for all samples in the dataset.
def targets(self) -> list[int | list[int]]:
"""Return the integer targets for all samples in the dataset.
Returns
-------
list[int | list[int]]
The integer labels for all samples in the dataset.
The integer targets for all samples in the dataset.
"""
return [k[1]["label"] for k in self.data]
return [k[1]["target"] for k in self.data]
def __getitem__(self, key: int) -> Sample:
return self.data[key]
......@@ -294,16 +283,16 @@ class _ConcatDataset(Dataset):
for j in range(len(datasets[i]))
]
def labels(self) -> list[int | list[int]]:
"""Return the integer labels for all samples in the dataset.
def targets(self) -> list[int | list[int]]:
"""Return the integer targets for all samples in the dataset.
Returns
-------
list[int | list[int]]
The integer labels for all samples in the dataset.
The integer targets for all samples in the dataset.
"""
return list(itertools.chain(*[k.labels() for k in self._datasets]))
return list(itertools.chain(*[k.targets() for k in self._datasets]))
def __getitem__(self, key: int) -> Sample:
i, j = self._indices[key]
......@@ -317,6 +306,145 @@ class _ConcatDataset(Dataset):
yield from dataset
def _make_balanced_random_sampler(
dataset: Dataset,
target: str = "target",
) -> torch.utils.data.WeightedRandomSampler:
"""Generate a pytorch sampler that samples according to class
probabilities.
This function takes as input a torch Dataset, and computes the weights to
balance each class in the dataset, and the datasets themselves if one
passes a :py:class:`torch.utils.data.ConcatDataset`.
In this implementation, we balance **both** class and dataset-origin
probabilities, what you expect for a truly *equitable* random sampler.
Take this example for illustration:
* Dataset 1: N = 10 samples, 9 samples with target=0, 1 sample with target=1
* Dataset 2: N = 6 samples, 3 samples with target=0, 3 samples with target=1
So:
| Dataset | Target | Samples | Weight | Normalised weight |
+---------+--------+---------+--------+-------------------+
| 1 | 0 | 9 | 1/9 | 1/36 |
| 1 | 1 | 1 | 1/1 | 1/4 |
| 2 | 0 | 3 | 1/3 | 1/12 |
| 2 | 1 | 3 | 1/3 | 1/12 |
Legend:
* Weight: the weights computed by this method
* Normalised weight: the weight per sample used by the random sampler,
after normalising the weights by the sum of all weights in the
concatenated dataset, such that the sum of all normalized weights times
the number of samples is 1.
The properties of this algorithm are as follows:
1. The probability of picking a sample from any target is the same (0.5 in
this case). To verify this, notice that the probability of picking a
sample with ``target=0`` is :math:`1/4 x 1 + 1/12 x 3 = 0.5`.
2. The probability of picking a sample with ``target=0`` from Dataset 2 is
3 times higher than those from Dataset 1. As there are 3 times fewer
samples in Dataset 2 with ``target=0``, this makes choosing samples from
Dataset 1 proportionally less likely.
3. The probability of picking a sample with ``target=1`` from Dataset 2 is
3 times lower than those from Dataset 1. As there are 3 times fewer
samples in Dataset 1 with ``target=1``, this makes choosing samples from
Dataset 2 proportionally less likely.
This function assumes targets are stored on a dictionary entry named
``target`` inside the metadata information for the
:py:data:`.typing.Sample`, and that its value is an integer.
We then instantiate a pytorch sampler using the inverse probabilities (the
more samples in a class, the less likely it becomes to be sampled.
Parameters
----------
dataset
An instance of torch Dataset.
:py:class:`torch.utils.data.ConcatDataset` are supported.
target
The name of a metadata key pointing to an integer property that allows
balancing the dataset.
Returns
-------
A sampler, to be used in a dataloader equipped with the same dataset
used to calculate the relative sample weights.
Raises
------
RuntimeError
If requested to balance a dataset (single, not-concatenated) without an
existing target.
"""
def _calculate_weights(targets: list[int]) -> list[float]:
counts = collections.Counter(targets)
weights = {k: 1.0 / v for k, v in counts.items()}
return [weights[k] for k in targets]
if isinstance(dataset, torch.utils.data.ConcatDataset):
# There are two possible cases: targets/no-targets
metadata_example = dataset.datasets[0][0][1]
if target in metadata_example and isinstance(
metadata_example[target],
int,
):
# there are integer targets, let's balance with those
logger.info(
f"Balancing sample selection probabilities **and** "
f"concatenated-datasets using metadata targets `{target}`",
)
targets = [
k
for ds in dataset.datasets
for k in typing.cast(Dataset, ds).targets()
]
weights = _calculate_weights(targets) # type: ignore
else:
logger.warning(
f"Balancing samples **and** concatenated-datasets "
f"by using dataset totals as `{target}: int` is not true",
)
weights = [
k
for ds in dataset.datasets
for k in len(typing.cast(typing.Sized, ds))
* [1.0 / len(typing.cast(typing.Sized, ds))]
]
pass
else:
metadata_example = dataset[0][1]
if target in metadata_example and isinstance(
metadata_example[target],
int,
):
logger.info(
f"Balancing samples from dataset using metadata "
f"targets `{target}`",
)
weights = _calculate_weights(dataset.targets()) # type: ignore
else:
raise RuntimeError(
f"Cannot balance samples with multiple class targets "
f"({target}: list[int]) or without metadata targets `{target}`",
)
return torch.utils.data.WeightedRandomSampler(
weights,
len(weights),
replacement=True,
)
class ConcatDataModule(lightning.LightningDataModule):
"""A conveninent DataModule with dictionary split loading, mini- batching,
parallelisation and caching, all in one.
......@@ -800,7 +928,7 @@ class CachingDataModule(ConcatDataModule):
not influence any early stop criteria during training, and are just
monitored beyond the ``validation`` dataset.
raw_data_loader
An object instance that can load samples and labels from storage.
An object instance that can load samples from storage.
**kwargs
List of named parameters matching those of
:py:class:`ConcatDataModule`, other than ``splits``.
......
......@@ -121,7 +121,7 @@ def check_database_split_loading(
name of a dataset in the split. Each value is a (potentially complex)
object that represents a single sample.
loader
A loader object that knows how to handle full-samples or just labels.
A loader object that knows how to handle full-samples.
limit
Maximum number of samples to check (in each split/dataset
combination) in this dataset. If set to zero, then check
......
......@@ -17,12 +17,12 @@ First parameter
Second parameter
A dictionary containing a named set of meta-data. One the most common is
the ``label`` entry.
the ``target`` entry.
"""
class RawDataLoader:
"""A loader object can load samples and labels from storage."""
"""A loader object can load samples from storage."""
def sample(self, _: typing.Any) -> Sample:
"""Load whole samples from media.
......@@ -77,9 +77,9 @@ class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized):
provide a dunder len method.
"""
def labels(self) -> list[int | list[int]]:
"""Return the integer labels for all samples in the dataset."""
raise NotImplementedError("You must implement the `labels()` method")
def targets(self) -> list[int | list[int]]:
"""Return the integer targets for all samples in the dataset."""
raise NotImplementedError("You must implement the `targets()` method")
DataLoader: typing.TypeAlias = torch.utils.data.DataLoader[Sample]
......
......@@ -42,9 +42,8 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
Parameters
----------
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing the
sample label.
A tuple containing path suffixes to the sample image, target, and mask
to be loaded, within the dataset root folder.
Returns
-------
......@@ -55,7 +54,7 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
mode="RGB"
)
tensor = tv_tensors.Image(to_tensor(image))
label = tv_tensors.Image(
target = tv_tensors.Image(
to_tensor(
PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
mode="1", dither=None
......@@ -70,7 +69,7 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
)
)
return tensor, dict(label=label, mask=mask, name=sample[0]) # type: ignore[arg-type]
return tensor, dict(target=target, mask=mask, name=sample[0]) # type: ignore[arg-type]
def make_split(basename: str) -> DatabaseSplit:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment