Skip to content
Snippets Groups Projects
Commit 4a2375d7 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[dataloader] Rename label to target

parent 800081d1
No related branches found
No related tags found
1 merge request!46Create common library
Showing
with 301 additions and 83 deletions
...@@ -51,7 +51,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): ...@@ -51,7 +51,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing where to find the image to be loaded, and an integer, representing
the sample label. the sample target.
Returns Returns
------- -------
...@@ -67,22 +67,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): ...@@ -67,22 +67,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
# to_pil_image(tensor).show() # to_pil_image(tensor).show()
# __import__("pdb").set_trace() # __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type]
def label(self, sample: tuple[str, int]) -> int: def target(self, sample: tuple[str, int]) -> int:
"""Load a single image sample label from the disk. """Load a single image sample target from the disk.
Parameters Parameters
---------- ----------
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing where to find the image to be loaded, and an integer, representing
the sample label. the sample target.
Returns Returns
------- -------
int int
The integer label associated with the sample. The integer target associated with the sample.
""" """
return sample[1] return sample[1]
......
...@@ -57,7 +57,7 @@ class RawDataLoader(_BaseRawDataLoader): ...@@ -57,7 +57,7 @@ class RawDataLoader(_BaseRawDataLoader):
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing where to find the image to be loaded, and an integer, representing
the sample label. the sample target.
Returns Returns
------- -------
...@@ -75,22 +75,22 @@ class RawDataLoader(_BaseRawDataLoader): ...@@ -75,22 +75,22 @@ class RawDataLoader(_BaseRawDataLoader):
# to_pil_image(tensor).show() # to_pil_image(tensor).show()
# __import__("pdb").set_trace() # __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type]
def label(self, sample: tuple[str, int]) -> int: def target(self, sample: tuple[str, int]) -> int:
"""Load a single image sample label from the disk. """Load a single image sample target from the disk.
Parameters Parameters
---------- ----------
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing where to find the image to be loaded, and an integer, representing
the sample label. the sample target.
Returns Returns
------- -------
int int
The integer label associated with the sample. The integer target associated with the sample.
""" """
return sample[1] return sample[1]
......
...@@ -74,7 +74,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): ...@@ -74,7 +74,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing where to find the image to be loaded, and an integer, representing
the sample label. the sample target.
Returns Returns
------- -------
...@@ -99,22 +99,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): ...@@ -99,22 +99,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
# to_pil_image(tensor).show() # to_pil_image(tensor).show()
# __import__("pdb").set_trace() # __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type]
def label(self, sample: tuple[str, list[int]]) -> list[int]: def target(self, sample: tuple[str, list[int]]) -> list[int]:
"""Load a single image sample label from the disk. """Load a single image sample target from the disk.
Parameters Parameters
---------- ----------
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing the where to find the image to be loaded, and an integer, representing the
sample label. sample target.
Returns Returns
------- -------
list[int] list[int]
The integer labels associated with the sample. The integer targets associated with the sample.
""" """
return sample[1] return sample[1]
......
...@@ -52,7 +52,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): ...@@ -52,7 +52,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing where to find the image to be loaded, and an integer, representing
the sample label. the sample target.
Returns Returns
------- -------
...@@ -70,22 +70,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): ...@@ -70,22 +70,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
# to_pil_image(tensor).show() # to_pil_image(tensor).show()
# __import__("pdb").set_trace() # __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type]
def label(self, sample: tuple[str, int | list[int]]) -> int | list[int]: def target(self, sample: tuple[str, int | list[int]]) -> int | list[int]:
"""Load a single image sample label from the disk. """Load a single image sample target from the disk.
Parameters Parameters
---------- ----------
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing where to find the image to be loaded, and an integer, representing
the sample label. the sample target.
Returns Returns
------- -------
list[int] list[int]
The integer labels associated with the sample. The integer targets associated with the sample.
""" """
return sample[1] return sample[1]
......
...@@ -57,7 +57,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): ...@@ -57,7 +57,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing where to find the image to be loaded, and an integer, representing
the sample label. the sample target.
Returns Returns
------- -------
...@@ -75,22 +75,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): ...@@ -75,22 +75,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
# to_pil_image(tensor).show() # to_pil_image(tensor).show()
# __import__("pdb").set_trace() # __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type]
def label(self, sample: tuple[str, int]) -> int: def target(self, sample: tuple[str, int]) -> int:
"""Load a single image sample label from the disk. """Load a single image sample target from the disk.
Parameters Parameters
---------- ----------
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing where to find the image to be loaded, and an integer, representing
the sample label. the sample target.
Returns Returns
------- -------
int int
The integer label associated with the sample. The integer target associated with the sample.
""" """
return sample[1] return sample[1]
......
...@@ -47,7 +47,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): ...@@ -47,7 +47,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing where to find the image to be loaded, and an integer, representing
the sample label. the sample target.
Returns Returns
------- -------
...@@ -65,21 +65,21 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): ...@@ -65,21 +65,21 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
# to_pil_image(tensor).show() # to_pil_image(tensor).show()
# __import__("pdb").set_trace() # __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type]
def label(self, sample: tuple[str, int]) -> int: def target(self, sample: tuple[str, int]) -> int:
"""Load a single image sample label from the disk. """Load a single image sample target from the disk.
Parameters Parameters
---------- ----------
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing where to find the image to be loaded, and an integer, representing
the sample label. the sample target.
Returns Returns
------- -------
The integer label associated with the sample The integer target associated with the sample
""" """
return sample[1] return sample[1]
......
...@@ -199,7 +199,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): ...@@ -199,7 +199,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, an integer, representing the where to find the image to be loaded, an integer, representing the
sample label, and possible radiological findings represented by sample target, and possible radiological findings represented by
bounding boxes. bounding boxes.
Returns Returns
...@@ -222,26 +222,26 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): ...@@ -222,26 +222,26 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
) )
return tensor, dict( return tensor, dict(
label=sample[1],
name=sample[0], name=sample[0],
target=sample[1],
bounding_boxes=self.bounding_boxes(sample), bounding_boxes=self.bounding_boxes(sample),
) )
def label(self, sample: DatabaseSample) -> int: def target(self, sample: DatabaseSample) -> int:
"""Load a single image sample label from the disk. """Load a single image sample target from the disk.
Parameters Parameters
---------- ----------
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, an integer, representing the where to find the image to be loaded, an integer, representing the
sample label, and possible radiological findings represented by sample target, and possible radiological findings represented by
bounding boxes. bounding boxes.
Returns Returns
------- -------
int int
The integer label associated with the sample. The integer target associated with the sample.
""" """
return sample[1] return sample[1]
...@@ -254,7 +254,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): ...@@ -254,7 +254,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader):
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, an integer, representing the where to find the image to be loaded, an integer, representing the
sample label, and possible radiological findings represented by sample target, and possible radiological findings represented by
bounding boxes. bounding boxes.
Returns Returns
......
...@@ -13,6 +13,7 @@ import torchvision.models as models ...@@ -13,6 +13,7 @@ import torchvision.models as models
import torchvision.transforms import torchvision.transforms
from mednet.libs.common.data.typing import TransformSequence from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model from mednet.libs.common.models.model import Model
from .separate import separate from .separate import separate
from .transforms import RGB, SquareCenterPad from .transforms import RGB, SquareCenterPad
...@@ -129,7 +130,7 @@ class Alexnet(Model): ...@@ -129,7 +130,7 @@ class Alexnet(Model):
def training_step(self, batch, _): def training_step(self, batch, _):
images = batch[0] images = batch[0]
labels = batch[1]["label"] labels = batch[1]["target"]
# Increase label dimension if too low # Increase label dimension if too low
# Allows single and multiclass usage # Allows single and multiclass usage
...@@ -143,7 +144,7 @@ class Alexnet(Model): ...@@ -143,7 +144,7 @@ class Alexnet(Model):
def validation_step(self, batch, batch_idx, dataloader_idx=0): def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[0] images = batch[0]
labels = batch[1]["label"] labels = batch[1]["target"]
# Increase label dimension if too low # Increase label dimension if too low
# Allows single and multiclass usage # Allows single and multiclass usage
......
...@@ -13,6 +13,7 @@ import torchvision.models as models ...@@ -13,6 +13,7 @@ import torchvision.models as models
import torchvision.transforms import torchvision.transforms
from mednet.libs.common.data.typing import TransformSequence from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model from mednet.libs.common.models.model import Model
from .separate import separate from .separate import separate
from .transforms import RGB, SquareCenterPad from .transforms import RGB, SquareCenterPad
...@@ -136,7 +137,7 @@ class Densenet(Model): ...@@ -136,7 +137,7 @@ class Densenet(Model):
def training_step(self, batch, _): def training_step(self, batch, _):
images = batch[0] images = batch[0]
labels = batch[1]["label"] labels = batch[1]["target"]
# Increase label dimension if too low # Increase label dimension if too low
# Allows single and multiclass usage # Allows single and multiclass usage
...@@ -150,7 +151,7 @@ class Densenet(Model): ...@@ -150,7 +151,7 @@ class Densenet(Model):
def validation_step(self, batch, batch_idx, dataloader_idx=0): def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[0] images = batch[0]
labels = batch[1]["label"] labels = batch[1]["target"]
# Increase label dimension if too low # Increase label dimension if too low
# Allows single and multiclass usage # Allows single and multiclass usage
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import torch
import torch.utils.data
from mednet.libs.common.data.typing import DataLoader
logger = logging.getLogger("mednet")
def _get_label_weights(
    dataloader: torch.utils.data.DataLoader,
) -> torch.Tensor:
    """Compute positive-class weights from the targets of a DataLoader.

    This function iterates once over a pytorch DataLoader, gathers the
    ``target`` entry of every batch, and derives weights that can be used to
    adjust minimisation criteria in cases where there is a huge data
    imbalance.

    Two layouts are handled, depending on the rank of the gathered targets:

    * 1-D targets (one value per sample): returns a single-element vector
      with the ratio between the number of samples carrying the smallest
      target value (negatives) and the number carrying the second smallest
      (positives).
    * Any other rank (one row per sample, one column per class): returns one
      weight per column, computed as the number of negatives divided by the
      total number of samples.

    Parameters
    ----------
    dataloader
        A DataLoader from which to compute the positive weights.  The second
        element of each batch must be a dictionary which must contain a
        ``target`` key.

    Returns
    -------
    torch.Tensor
        The positive weight of each class in the dataset given as input.

    Raises
    ------
    ValueError
        If the targets are 1-D and hold fewer than two distinct values (this
        includes an empty dataloader), in which case the negative/positive
        ratio is undefined.
    """

    targets = torch.tensor(
        [sample for batch in dataloader for sample in batch[1]["target"]],
    )

    # Binary labels
    if targets.ndim == 1:
        class_sample_count = [
            float((targets == t).sum().item())
            for t in torch.unique(targets, sorted=True)
        ]

        # Without both classes the ratio below cannot be computed; fail with
        # an explicit message instead of an IndexError on the list access.
        if len(class_sample_count) < 2:
            raise ValueError(
                f"Cannot compute positive weights: expected at least two "
                f"distinct target values, found {len(class_sample_count)}",
            )

        # Divide negatives by positives
        positive_weights = torch.tensor(
            [class_sample_count[0] / class_sample_count[1]],
        ).reshape(-1)

    # Multiclass labels
    else:
        class_sample_count = torch.sum(targets, dim=0)
        negative_class_sample_count = (
            torch.full((targets.size()[1],), float(targets.size()[0]))
            - class_sample_count
        )

        # NOTE(review): this is negatives / total, whereas the binary branch
        # uses negatives / positives -- confirm this difference is intended.
        positive_weights = negative_class_sample_count / (
            class_sample_count + negative_class_sample_count
        )

    return positive_weights
def make_balanced_bcewithlogitsloss(
    dataloader: DataLoader,
) -> torch.nn.BCEWithLogitsLoss:
    """Build a binary-cross-entropy-with-logits loss with class weighting.

    The per-class weights are obtained by passing ``dataloader`` to
    :py:func:`_get_label_weights`, and are handed over to the loss as its
    ``pos_weight`` argument.

    Parameters
    ----------
    dataloader
        The DataLoader whose targets are used to compute the BCE weights.

    Returns
    -------
    torch.nn.BCEWithLogitsLoss
        An instance of the weighted loss.
    """

    return torch.nn.BCEWithLogitsLoss(
        pos_weight=_get_label_weights(dataloader),
    )
...@@ -13,6 +13,7 @@ import torch.utils.data ...@@ -13,6 +13,7 @@ import torch.utils.data
import torchvision.transforms import torchvision.transforms
from mednet.libs.common.data.typing import TransformSequence from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model from mednet.libs.common.models.model import Model
from .separate import separate from .separate import separate
from .transforms import Grayscale, SquareCenterPad from .transforms import Grayscale, SquareCenterPad
...@@ -202,7 +203,7 @@ class Pasa(Model): ...@@ -202,7 +203,7 @@ class Pasa(Model):
def training_step(self, batch, _): def training_step(self, batch, _):
images = batch[0] images = batch[0]
labels = batch[1]["label"] labels = batch[1]["target"]
# Increase label dimension if too low # Increase label dimension if too low
# Allows single and multiclass usage # Allows single and multiclass usage
...@@ -216,7 +217,7 @@ class Pasa(Model): ...@@ -216,7 +217,7 @@ class Pasa(Model):
def validation_step(self, batch, batch_idx, dataloader_idx=0): def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[0] images = batch[0]
labels = batch[1]["label"] labels = batch[1]["target"]
# Increase label dimension if too low # Increase label dimension if too low
# Allows single and multiclass usage # Allows single and multiclass usage
......
...@@ -28,7 +28,7 @@ def _as_predictions( ...@@ -28,7 +28,7 @@ def _as_predictions(
A list of typed predictions that can be saved to disk. A list of typed predictions that can be saved to disk.
""" """
return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples] return [(v[1]["name"], v[1]["target"].item(), v[0].item()) for v in samples]
def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]: def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]:
......
...@@ -112,7 +112,7 @@ class _DelayedLoadingDataset(Dataset): ...@@ -112,7 +112,7 @@ class _DelayedLoadingDataset(Dataset):
An iterable containing the raw dataset samples representing one of the An iterable containing the raw dataset samples representing one of the
database split datasets. database split datasets.
loader loader
An object instance that can load samples and labels from storage. An object instance that can load samples from storage.
transforms transforms
A set of transforms that should be applied on-the-fly for this dataset, A set of transforms that should be applied on-the-fly for this dataset,
to fit the output of the raw-data-loader to the model of interest. to fit the output of the raw-data-loader to the model of interest.
...@@ -137,17 +137,6 @@ class _DelayedLoadingDataset(Dataset): ...@@ -137,17 +137,6 @@ class _DelayedLoadingDataset(Dataset):
sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0) sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb") logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
def labels(self) -> list[int | list[int]]:
"""Return the integer labels for all samples in the dataset.
Returns
-------
list[int | list[int]]
The integer labels for all samples in the dataset.
"""
return [self.loader.label(k) for k in self.raw_dataset]
def __getitem__(self, key: int) -> Sample: def __getitem__(self, key: int) -> Sample:
tensor, metadata = self.loader.sample(self.raw_dataset[key]) tensor, metadata = self.loader.sample(self.raw_dataset[key])
return self.transform(tensor), transform_tvtensors( return self.transform(tensor), transform_tvtensors(
...@@ -206,7 +195,7 @@ class _CachedDataset(Dataset): ...@@ -206,7 +195,7 @@ class _CachedDataset(Dataset):
An iterable containing the raw dataset samples representing one of the An iterable containing the raw dataset samples representing one of the
database split datasets. database split datasets.
loader loader
An object instance that can load samples and labels from storage. An object instance that can load samples and targets from storage.
parallel parallel
Use multiprocessing for data loading: if set to -1 (default), disables Use multiprocessing for data loading: if set to -1 (default), disables
multiprocessing data loading. Set to 0 to enable as many data loading multiprocessing data loading. Set to 0 to enable as many data loading
...@@ -255,16 +244,16 @@ class _CachedDataset(Dataset): ...@@ -255,16 +244,16 @@ class _CachedDataset(Dataset):
f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb", f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb",
) )
def labels(self) -> list[int | list[int]]: def targets(self) -> list[int | list[int]]:
"""Return the integer labels for all samples in the dataset. """Return the integer targets for all samples in the dataset.
Returns Returns
------- -------
list[int | list[int]] list[int | list[int]]
The integer labels for all samples in the dataset. The integer targets for all samples in the dataset.
""" """
return [k[1]["label"] for k in self.data] return [k[1]["target"] for k in self.data]
def __getitem__(self, key: int) -> Sample: def __getitem__(self, key: int) -> Sample:
return self.data[key] return self.data[key]
...@@ -294,16 +283,16 @@ class _ConcatDataset(Dataset): ...@@ -294,16 +283,16 @@ class _ConcatDataset(Dataset):
for j in range(len(datasets[i])) for j in range(len(datasets[i]))
] ]
def labels(self) -> list[int | list[int]]: def targets(self) -> list[int | list[int]]:
"""Return the integer labels for all samples in the dataset. """Return the integer targets for all samples in the dataset.
Returns Returns
------- -------
list[int | list[int]] list[int | list[int]]
The integer labels for all samples in the dataset. The integer targets for all samples in the dataset.
""" """
return list(itertools.chain(*[k.labels() for k in self._datasets])) return list(itertools.chain(*[k.targets() for k in self._datasets]))
def __getitem__(self, key: int) -> Sample: def __getitem__(self, key: int) -> Sample:
i, j = self._indices[key] i, j = self._indices[key]
...@@ -317,6 +306,145 @@ class _ConcatDataset(Dataset): ...@@ -317,6 +306,145 @@ class _ConcatDataset(Dataset):
yield from dataset yield from dataset
def _make_balanced_random_sampler(
    dataset: Dataset,
    target: str = "target",
) -> torch.utils.data.WeightedRandomSampler:
    """Generate a pytorch sampler that samples according to class
    probabilities.

    This function takes as input a torch Dataset, and computes the weights to
    balance each class in the dataset, and the datasets themselves if one
    passes a :py:class:`torch.utils.data.ConcatDataset`.

    In this implementation, we balance **both** class and dataset-origin
    probabilities, what you expect for a truly *equitable* random sampler.
    To achieve this, class weights are computed **separately for each
    concatenated dataset** (inverse of the class count within that dataset).

    Take this example for illustration:

    * Dataset 1: N = 10 samples, 9 samples with target=0, 1 sample with target=1
    * Dataset 2: N = 6 samples, 3 samples with target=0, 3 samples with target=1

    So:

    | Dataset | Target | Samples | Weight | Normalised weight |
    +---------+--------+---------+--------+-------------------+
    |    1    |   0    |    9    |  1/9   |       1/36        |
    |    1    |   1    |    1    |  1/1   |       1/4         |
    |    2    |   0    |    3    |  1/3   |       1/12        |
    |    2    |   1    |    3    |  1/3   |       1/12        |

    Legend:

    * Weight: the weights computed by this method
    * Normalised weight: the weight per sample used by the random sampler,
      after normalising the weights by the sum of all weights in the
      concatenated dataset, such that the sum of all normalized weights times
      the number of samples is 1.

    The properties of this algorithm are as follows:

    1. The probability of picking a sample from any target is the same (0.5 in
       this case).  To verify this, notice that the probability of picking a
       sample with ``target=1`` is :math:`1/4 x 1 + 1/12 x 3 = 0.5`, and that
       of picking a sample with ``target=0`` is
       :math:`1/36 x 9 + 1/12 x 3 = 0.5`.
    2. The probability of picking a sample with ``target=0`` from Dataset 2 is
       3 times higher than those from Dataset 1.  As there are 3 times fewer
       samples in Dataset 2 with ``target=0``, this makes choosing samples from
       Dataset 1 proportionally less likely.
    3. The probability of picking a sample with ``target=1`` from Dataset 2 is
       3 times lower than those from Dataset 1.  As there are 3 times fewer
       samples in Dataset 1 with ``target=1``, this makes choosing samples from
       Dataset 2 proportionally less likely.

    The metadata of the first sample is probed for an integer entry named
    after the ``target`` parameter; when present, the actual target values
    are obtained through the ``targets()`` method of the dataset(s).

    We then instantiate a pytorch sampler using the inverse probabilities (the
    more samples in a class, the less likely it becomes to be sampled).

    Parameters
    ----------
    dataset
        An instance of torch Dataset.
        :py:class:`torch.utils.data.ConcatDataset` are supported.
    target
        The name of a metadata key pointing to an integer property that allows
        balancing the dataset.

    Returns
    -------
    torch.utils.data.WeightedRandomSampler
        A sampler, to be used in a dataloader equipped with the same dataset
        used to calculate the relative sample weights.

    Raises
    ------
    RuntimeError
        If requested to balance a dataset (single, not-concatenated) without an
        existing target.
    """

    def _calculate_weights(targets: list[int]) -> list[float]:
        # Inverse class frequency: every sample of a class with ``n`` members
        # receives the weight ``1/n``.
        counts = collections.Counter(targets)
        return [1.0 / counts[k] for k in targets]

    weights: list[float]

    if isinstance(dataset, torch.utils.data.ConcatDataset):
        # There are two possible cases: targets/no-targets
        metadata_example = dataset.datasets[0][0][1]
        if target in metadata_example and isinstance(
            metadata_example[target],
            int,
        ):
            # there are integer targets, let's balance with those
            logger.info(
                f"Balancing sample selection probabilities **and** "
                f"concatenated-datasets using metadata targets `{target}`",
            )
            # Weights are computed per constituent dataset (not on the pooled
            # targets), so that every (dataset, class) pair carries the same
            # total probability mass, as documented above.
            weights = [
                w
                for ds in dataset.datasets
                for w in _calculate_weights(
                    typing.cast(Dataset, ds).targets(),  # type: ignore
                )
            ]
        else:
            logger.warning(
                f"Balancing samples **and** concatenated-datasets "
                f"by using dataset totals as `{target}: int` is not true",
            )
            # Without usable targets, only balance dataset origins: each
            # sample weighs the inverse of the size of its own dataset.
            weights = []
            for ds in dataset.datasets:
                size = len(typing.cast(typing.Sized, ds))
                weights.extend(size * [1.0 / size])
    else:
        metadata_example = dataset[0][1]
        if target in metadata_example and isinstance(
            metadata_example[target],
            int,
        ):
            logger.info(
                f"Balancing samples from dataset using metadata "
                f"targets `{target}`",
            )
            weights = _calculate_weights(dataset.targets())  # type: ignore
        else:
            raise RuntimeError(
                f"Cannot balance samples with multiple class targets "
                f"({target}: list[int]) or without metadata targets `{target}`",
            )

    return torch.utils.data.WeightedRandomSampler(
        weights,
        len(weights),
        replacement=True,
    )
class ConcatDataModule(lightning.LightningDataModule): class ConcatDataModule(lightning.LightningDataModule):
"""A convenient DataModule with dictionary split loading, mini-batching, """A convenient DataModule with dictionary split loading, mini-batching,
parallelisation and caching, all in one. parallelisation and caching, all in one.
...@@ -800,7 +928,7 @@ class CachingDataModule(ConcatDataModule): ...@@ -800,7 +928,7 @@ class CachingDataModule(ConcatDataModule):
not influence any early stop criteria during training, and are just not influence any early stop criteria during training, and are just
monitored beyond the ``validation`` dataset. monitored beyond the ``validation`` dataset.
raw_data_loader raw_data_loader
An object instance that can load samples and labels from storage. An object instance that can load samples from storage.
**kwargs **kwargs
List of named parameters matching those of List of named parameters matching those of
:py:class:`ConcatDataModule`, other than ``splits``. :py:class:`ConcatDataModule`, other than ``splits``.
......
...@@ -121,7 +121,7 @@ def check_database_split_loading( ...@@ -121,7 +121,7 @@ def check_database_split_loading(
name of a dataset in the split. Each value is a (potentially complex) name of a dataset in the split. Each value is a (potentially complex)
object that represents a single sample. object that represents a single sample.
loader loader
A loader object that knows how to handle full-samples or just labels. A loader object that knows how to handle full-samples.
limit limit
Maximum number of samples to check (in each split/dataset Maximum number of samples to check (in each split/dataset
combination) in this dataset. If set to zero, then check combination) in this dataset. If set to zero, then check
......
...@@ -17,12 +17,12 @@ First parameter ...@@ -17,12 +17,12 @@ First parameter
Second parameter Second parameter
A dictionary containing a named set of meta-data. One of the most common is A dictionary containing a named set of meta-data. One of the most common is
the ``label`` entry. the ``target`` entry.
""" """
class RawDataLoader: class RawDataLoader:
"""A loader object can load samples and labels from storage.""" """A loader object can load samples from storage."""
def sample(self, _: typing.Any) -> Sample: def sample(self, _: typing.Any) -> Sample:
"""Load whole samples from media. """Load whole samples from media.
...@@ -77,9 +77,9 @@ class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized): ...@@ -77,9 +77,9 @@ class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized):
provide a dunder len method. provide a dunder len method.
""" """
def labels(self) -> list[int | list[int]]: def targets(self) -> list[int | list[int]]:
"""Return the integer labels for all samples in the dataset.""" """Return the integer targets for all samples in the dataset."""
raise NotImplementedError("You must implement the `labels()` method") raise NotImplementedError("You must implement the `targets()` method")
DataLoader: typing.TypeAlias = torch.utils.data.DataLoader[Sample] DataLoader: typing.TypeAlias = torch.utils.data.DataLoader[Sample]
......
...@@ -42,9 +42,8 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader): ...@@ -42,9 +42,8 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
Parameters Parameters
---------- ----------
sample sample
A tuple containing the path suffix, within the dataset root folder, A tuple containing path suffixes to the sample image, target, and mask
where to find the image to be loaded, and an integer, representing the to be loaded, within the dataset root folder.
sample label.
Returns Returns
------- -------
...@@ -55,7 +54,7 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader): ...@@ -55,7 +54,7 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
mode="RGB" mode="RGB"
) )
tensor = tv_tensors.Image(to_tensor(image)) tensor = tv_tensors.Image(to_tensor(image))
label = tv_tensors.Image( target = tv_tensors.Image(
to_tensor( to_tensor(
PIL.Image.open(Path(self.datadir) / str(sample[1])).convert( PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
mode="1", dither=None mode="1", dither=None
...@@ -70,7 +69,7 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader): ...@@ -70,7 +69,7 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
) )
) )
return tensor, dict(label=label, mask=mask, name=sample[0]) # type: ignore[arg-type] return tensor, dict(target=target, mask=mask, name=sample[0]) # type: ignore[arg-type]
def make_split(basename: str) -> DatabaseSplit: def make_split(basename: str) -> DatabaseSplit:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment