diff --git a/src/mednet/libs/classification/config/data/hivtb/datamodule.py b/src/mednet/libs/classification/config/data/hivtb/datamodule.py index 5d8170b77d50d658b22f3716ed2890a2eb395545..a93606bb6547b5cd196ce58100613a41ed130169 100644 --- a/src/mednet/libs/classification/config/data/hivtb/datamodule.py +++ b/src/mednet/libs/classification/config/data/hivtb/datamodule.py @@ -51,7 +51,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): sample A tuple containing the path suffix, within the dataset root folder, where to find the image to be loaded, and an integer, representing - the sample label. + the sample target. Returns ------- @@ -67,22 +67,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): # to_pil_image(tensor).show() # __import__("pdb").set_trace() - return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] + return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type] - def label(self, sample: tuple[str, int]) -> int: - """Load a single image sample label from the disk. + def target(self, sample: tuple[str, int]) -> int: + """Load a single image sample target from the disk. Parameters ---------- sample A tuple containing the path suffix, within the dataset root folder, where to find the image to be loaded, and an integer, representing - the sample label. + the sample target. Returns ------- int - The integer label associated with the sample. + The integer target associated with the sample. """ return sample[1] diff --git a/src/mednet/libs/classification/config/data/montgomery/datamodule.py b/src/mednet/libs/classification/config/data/montgomery/datamodule.py index 83753205a142a5c06548527dc224a75849718768..2afbb8d6cd478371bff6f2828a737a5d93959183 100644 --- a/src/mednet/libs/classification/config/data/montgomery/datamodule.py +++ b/src/mednet/libs/classification/config/data/montgomery/datamodule.py @@ -57,7 +57,7 @@ class RawDataLoader(_BaseRawDataLoader): sample A tuple containing the path suffix, within the dataset root folder, where to find the image to be loaded, and an integer, representing - the sample label. + the sample target. Returns ------- @@ -75,22 +75,22 @@ class RawDataLoader(_BaseRawDataLoader): # to_pil_image(tensor).show() # __import__("pdb").set_trace() - return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] + return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type] - def label(self, sample: tuple[str, int]) -> int: - """Load a single image sample label from the disk. + def target(self, sample: tuple[str, int]) -> int: + """Load a single image sample target from the disk. Parameters ---------- sample A tuple containing the path suffix, within the dataset root folder, where to find the image to be loaded, and an integer, representing - the sample label. + the sample target. Returns ------- int - The integer label associated with the sample. + The integer target associated with the sample. 
""" return sample[1] diff --git a/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py b/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py index 900dcc9d0091a8571629bf283ee6c1829fa62286..0edb1aef91c0f0a19acbd83bf6dfcc1103b89fae 100644 --- a/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py +++ b/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py @@ -74,7 +74,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): sample A tuple containing the path suffix, within the dataset root folder, where to find the image to be loaded, and an integer, representing - the sample label. + the sample target. Returns ------- @@ -99,22 +99,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): # to_pil_image(tensor).show() # __import__("pdb").set_trace() - return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] + return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type] - def label(self, sample: tuple[str, list[int]]) -> list[int]: - """Load a single image sample label from the disk. + def target(self, sample: tuple[str, list[int]]) -> list[int]: + """Load a single image sample target from the disk. Parameters ---------- sample A tuple containing the path suffix, within the dataset root folder, where to find the image to be loaded, and an integer, representing the - sample label. + sample target. Returns ------- list[int] - The integer labels associated with the sample. + The integer targets associated with the sample. """ return sample[1] diff --git a/src/mednet/libs/classification/config/data/padchest/datamodule.py b/src/mednet/libs/classification/config/data/padchest/datamodule.py index 0fed49ab85d409ed6ae4b248de849cb31851a3eb..51a20bbdbc9392ce7917cb44b6db3a0233f9a857 100644 --- a/src/mednet/libs/classification/config/data/padchest/datamodule.py +++ b/src/mednet/libs/classification/config/data/padchest/datamodule.py @@ -52,7 +52,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): sample A tuple containing the path suffix, within the dataset root folder, where to find the image to be loaded, and an integer, representing - the sample label. + the sample target. Returns ------- @@ -70,22 +70,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): # to_pil_image(tensor).show() # __import__("pdb").set_trace() - return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] + return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type] - def label(self, sample: tuple[str, int | list[int]]) -> int | list[int]: - """Load a single image sample label from the disk. + def target(self, sample: tuple[str, int | list[int]]) -> int | list[int]: + """Load a single image sample target from the disk. Parameters ---------- sample A tuple containing the path suffix, within the dataset root folder, where to find the image to be loaded, and an integer, representing - the sample label. + the sample target. Returns ------- list[int] - The integer labels associated with the sample. + The integer targets associated with the sample. 
""" return sample[1] diff --git a/src/mednet/libs/classification/config/data/shenzhen/datamodule.py b/src/mednet/libs/classification/config/data/shenzhen/datamodule.py index 7ab4d171825e4894c4d25eb0dfe35c72ba7229a9..45aa327e06abbb22d9f7bbbc6f957a49291a4a06 100644 --- a/src/mednet/libs/classification/config/data/shenzhen/datamodule.py +++ b/src/mednet/libs/classification/config/data/shenzhen/datamodule.py @@ -57,7 +57,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): sample A tuple containing the path suffix, within the dataset root folder, where to find the image to be loaded, and an integer, representing - the sample label. + the sample target. Returns ------- @@ -75,22 +75,22 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): # to_pil_image(tensor).show() # __import__("pdb").set_trace() - return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] + return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type] - def label(self, sample: tuple[str, int]) -> int: - """Load a single image sample label from the disk. + def target(self, sample: tuple[str, int]) -> int: + """Load a single image sample target from the disk. Parameters ---------- sample A tuple containing the path suffix, within the dataset root folder, where to find the image to be loaded, and an integer, representing - the sample label. + the sample target. Returns ------- int - The integer label associated with the sample. + The integer target associated with the sample. """ return sample[1] diff --git a/src/mednet/libs/classification/config/data/tbpoc/datamodule.py b/src/mednet/libs/classification/config/data/tbpoc/datamodule.py index 2ba9f688a1abc5a3fc38a45ae92d4d6f7f2d38be..f92afc9d2e2ac141f7c4889db32e72c187f93f13 100644 --- a/src/mednet/libs/classification/config/data/tbpoc/datamodule.py +++ b/src/mednet/libs/classification/config/data/tbpoc/datamodule.py @@ -47,7 +47,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): sample A tuple containing the path suffix, within the dataset root folder, where to find the image to be loaded, and an integer, representing - the sample label. + the sample target. Returns ------- @@ -65,21 +65,21 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): # to_pil_image(tensor).show() # __import__("pdb").set_trace() - return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] + return tensor, dict(target=sample[1], name=sample[0]) # type: ignore[arg-type] - def label(self, sample: tuple[str, int]) -> int: - """Load a single image sample label from the disk. + def target(self, sample: tuple[str, int]) -> int: + """Load a single image sample target from the disk. Parameters ---------- sample A tuple containing the path suffix, within the dataset root folder, where to find the image to be loaded, and an integer, representing - the sample label. + the sample target. 
Returns ------- - The integer label associated with the sample + The integer target associated with the sample. """ return sample[1] diff --git a/src/mednet/libs/classification/config/data/tbx11k/datamodule.py b/src/mednet/libs/classification/config/data/tbx11k/datamodule.py index fb86030754971e0dc6551009d161a4965f029bac..47e97db8d085176e94c9d111ee17d63dece108dc 100644 --- a/src/mednet/libs/classification/config/data/tbx11k/datamodule.py +++ b/src/mednet/libs/classification/config/data/tbx11k/datamodule.py @@ -199,7 +199,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): sample A tuple containing the path suffix, within the dataset root folder, where to find the image to be loaded, an integer, representing the - sample label, and possible radiological findings represented by + sample target, and possible radiological findings represented by bounding boxes. Returns @@ -222,26 +222,26 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): ) return tensor, dict( - label=sample[1], name=sample[0], + target=sample[1], bounding_boxes=self.bounding_boxes(sample), ) - def label(self, sample: DatabaseSample) -> int: - """Load a single image sample label from the disk. + def target(self, sample: DatabaseSample) -> int: + """Load a single image sample target from the disk. Parameters ---------- sample A tuple containing the path suffix, within the dataset root folder, where to find the image to be loaded, an integer, representing the - sample label, and possible radiological findings represented by + sample target, and possible radiological findings represented by bounding boxes. Returns ------- int - The integer label associated with the sample. + The integer target associated with the sample. """ return sample[1] @@ -254,7 +254,7 @@ class ClassificationRawDataLoader(_ClassificationRawDataLoader): sample A tuple containing the path suffix, within the dataset root folder, where to find the image to be loaded, an integer, representing the - sample label, and possible radiological findings represented by + sample target, and possible radiological findings represented by bounding boxes.
Returns diff --git a/src/mednet/libs/classification/models/alexnet.py b/src/mednet/libs/classification/models/alexnet.py index 03802839900ee4bd3f5d8b368cb8d611cbc6b006..674046ca050d695ff3c81916e1e8fc2b31fe0367 100644 --- a/src/mednet/libs/classification/models/alexnet.py +++ b/src/mednet/libs/classification/models/alexnet.py @@ -13,6 +13,7 @@ import torchvision.models as models import torchvision.transforms from mednet.libs.common.data.typing import TransformSequence from mednet.libs.common.models.model import Model + from .separate import separate from .transforms import RGB, SquareCenterPad @@ -129,7 +130,7 @@ class Alexnet(Model): def training_step(self, batch, _): images = batch[0] - labels = batch[1]["label"] + labels = batch[1]["target"] # Increase label dimension if too low # Allows single and multiclass usage @@ -143,7 +144,7 @@ class Alexnet(Model): def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[0] - labels = batch[1]["label"] + labels = batch[1]["target"] # Increase label dimension if too low # Allows single and multiclass usage diff --git a/src/mednet/libs/classification/models/densenet.py b/src/mednet/libs/classification/models/densenet.py index 69ac44b72d22b9cd0ca932e00a6dee51fd1f93bd..a4e948609cb5f4273fab491985d2d85c7339000c 100644 --- a/src/mednet/libs/classification/models/densenet.py +++ b/src/mednet/libs/classification/models/densenet.py @@ -13,6 +13,7 @@ import torchvision.models as models import torchvision.transforms from mednet.libs.common.data.typing import TransformSequence from mednet.libs.common.models.model import Model + from .separate import separate from .transforms import RGB, SquareCenterPad @@ -136,7 +137,7 @@ class Densenet(Model): def training_step(self, batch, _): images = batch[0] - labels = batch[1]["label"] + labels = batch[1]["target"] # Increase label dimension if too low # Allows single and multiclass usage @@ -150,7 +151,7 @@ class Densenet(Model): def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[0] - labels = batch[1]["label"] + labels = batch[1]["target"] # Increase label dimension if too low # Allows single and multiclass usage diff --git a/src/mednet/libs/classification/models/loss_weights.py b/src/mednet/libs/classification/models/loss_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..5f646e322fcdf4ac60303ff41bf35e0037e2074f --- /dev/null +++ b/src/mednet/libs/classification/models/loss_weights.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import logging + +import torch +import torch.utils.data +from mednet.libs.common.data.typing import DataLoader + +logger = logging.getLogger("mednet") + + +def _get_label_weights( + dataloader: torch.utils.data.DataLoader, +) -> torch.Tensor: + """Compute the weights of each class of a DataLoader. + + This function takes a pytorch DataLoader and computes the ratio between the + number of negative and positive samples. The weight can be used to adjust + minimisation criteria in cases where there is a large data imbalance. + + It returns a vector with weights (inverse counts) for each target. + + Parameters + ---------- + dataloader + A DataLoader from which to compute the positive weights. Entries must + be a dictionary which must contain a ``target`` key. + + Returns + ------- + torch.Tensor + The positive weight of each class in the dataset given as input.
+ """ + + targets = torch.tensor( + [sample for batch in dataloader for sample in batch[1]["target"]], + ) + + # Binary labels + if len(list(targets.shape)) == 1: + class_sample_count = [ + float((targets == t).sum().item()) + for t in torch.unique(targets, sorted=True) + ] + + # Divide negatives by positives + positive_weights = torch.tensor( + [class_sample_count[0] / class_sample_count[1]], + ).reshape(-1) + + # Multiclass labels + else: + class_sample_count = torch.sum(targets, dim=0) + negative_class_sample_count = ( + torch.full((targets.size()[1],), float(targets.size()[0])) + - class_sample_count + ) + + positive_weights = negative_class_sample_count / ( + class_sample_count + negative_class_sample_count + ) + + return positive_weights + + +def make_balanced_bcewithlogitsloss( + dataloader: DataLoader, +) -> torch.nn.BCEWithLogitsLoss: + """Return a balanced binary-cross-entropy loss. + + The loss is weighted using the ratio between positives and total examples + available. + + Parameters + ---------- + dataloader + The DataLoader to use to compute the BCE weights. + + Returns + ------- + torch.nn.BCEWithLogitsLoss + An instance of the weighted loss. + """ + + weights = _get_label_weights(dataloader) + return torch.nn.BCEWithLogitsLoss(pos_weight=weights) diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py index 35953ddfd14377ce9a88267a154fbe99a6b8f9b2..e2e47010ed431fc8436591c264bb7911531d5f52 100644 --- a/src/mednet/libs/classification/models/pasa.py +++ b/src/mednet/libs/classification/models/pasa.py @@ -13,6 +13,7 @@ import torch.utils.data import torchvision.transforms from mednet.libs.common.data.typing import TransformSequence from mednet.libs.common.models.model import Model + from .separate import separate from .transforms import Grayscale, SquareCenterPad @@ -202,7 +203,7 @@ class Pasa(Model): def training_step(self, batch, _): images = batch[0] - labels = batch[1]["label"] + labels = batch[1]["target"] # Increase label dimension if too low # Allows single and multiclass usage @@ -216,7 +217,7 @@ class Pasa(Model): def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[0] - labels = batch[1]["label"] + labels = batch[1]["target"] # Increase label dimension if too low # Allows single and multiclass usage diff --git a/src/mednet/libs/classification/models/separate.py b/src/mednet/libs/classification/models/separate.py index 5d6065a867ac2bd104d9495a757756a65d2ae6d7..9b575e8ee1a39258b2437e1c968b303aff7b028b 100644 --- a/src/mednet/libs/classification/models/separate.py +++ b/src/mednet/libs/classification/models/separate.py @@ -28,7 +28,7 @@ def _as_predictions( A list of typed predictions that can be saved to disk. """ - return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples] + return [(v[1]["name"], v[1]["target"].item(), v[0].item()) for v in samples] def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]: diff --git a/src/mednet/libs/common/data/datamodule.py b/src/mednet/libs/common/data/datamodule.py index 30fc26bd9ce5b5817e4b45db7d31321c3ad5cffe..32e40ca0eaf919e501089eddbfda1768a5a450eb 100644 --- a/src/mednet/libs/common/data/datamodule.py +++ b/src/mednet/libs/common/data/datamodule.py @@ -112,7 +112,7 @@ class _DelayedLoadingDataset(Dataset): An iterable containing the raw dataset samples representing one of the database split datasets. loader - An object instance that can load samples and labels from storage. 
+ An object instance that can load samples from storage. transforms A set of transforms that should be applied on-the-fly for this dataset, to fit the output of the raw-data-loader to the model of interest. @@ -137,17 +137,6 @@ class _DelayedLoadingDataset(Dataset): sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0) logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb") - def labels(self) -> list[int | list[int]]: - """Return the integer labels for all samples in the dataset. - - Returns - ------- - list[int | list[int]] - The integer labels for all samples in the dataset. - """ - - return [self.loader.label(k) for k in self.raw_dataset] - def __getitem__(self, key: int) -> Sample: tensor, metadata = self.loader.sample(self.raw_dataset[key]) return self.transform(tensor), transform_tvtensors( @@ -206,7 +195,7 @@ class _CachedDataset(Dataset): An iterable containing the raw dataset samples representing one of the database split datasets. loader - An object instance that can load samples and labels from storage. + An object instance that can load samples and targets from storage. parallel Use multiprocessing for data loading: if set to -1 (default), disables multiprocessing data loading. Set to 0 to enable as many data loading @@ -255,16 +244,16 @@ class _CachedDataset(Dataset): f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb", ) - def labels(self) -> list[int | list[int]]: - """Return the integer labels for all samples in the dataset. + def targets(self) -> list[int | list[int]]: + """Return the integer targets for all samples in the dataset. Returns ------- list[int | list[int]] - The integer labels for all samples in the dataset. + The integer targets for all samples in the dataset. """ - return [k[1]["label"] for k in self.data] + return [k[1]["target"] for k in self.data] def __getitem__(self, key: int) -> Sample: return self.data[key] @@ -294,16 +283,16 @@ class _ConcatDataset(Dataset): for j in range(len(datasets[i])) ] - def labels(self) -> list[int | list[int]]: - """Return the integer labels for all samples in the dataset. + def targets(self) -> list[int | list[int]]: + """Return the integer targets for all samples in the dataset. Returns ------- list[int | list[int]] - The integer labels for all samples in the dataset. + The integer targets for all samples in the dataset. """ - return list(itertools.chain(*[k.labels() for k in self._datasets])) + return list(itertools.chain(*[k.targets() for k in self._datasets])) def __getitem__(self, key: int) -> Sample: i, j = self._indices[key] @@ -317,6 +306,143 @@ yield from dataset + +def _make_balanced_random_sampler( + dataset: Dataset, + target: str = "target", +) -> torch.utils.data.WeightedRandomSampler: + """Generate a pytorch sampler that samples according to class + probabilities. + + This function takes as input a torch Dataset and computes the weights to + balance each class in the dataset, and the datasets themselves if one + passes a :py:class:`torch.utils.data.ConcatDataset`. + + In this implementation, we balance **both** class and dataset-origin + probabilities, which one expects from a truly *equitable* random sampler.
+ + Take this example for illustration: + + * Dataset 1: N = 10 samples, 9 samples with target=0, 1 sample with target=1 + * Dataset 2: N = 6 samples, 3 samples with target=0, 3 samples with target=1 + + So: + + | Dataset | Target | Samples | Weight | Normalised weight | + +---------+--------+---------+--------+-------------------+ + | 1 | 0 | 9 | 1/9 | 1/36 | + | 1 | 1 | 1 | 1/1 | 1/4 | + | 2 | 0 | 3 | 1/3 | 1/12 | + | 2 | 1 | 3 | 1/3 | 1/12 | + + Legend: + + * Weight: the weights computed by this method + * Normalised weight: the weight per sample used by the random sampler, + after normalising the weights by the sum of all weights in the + concatenated dataset, such that the sum of all normalised weights + over all samples equals 1. + + The properties of this algorithm are as follows: + + 1. The probability of picking a sample of either target is the same (0.5 in + this case). To verify this, notice that the probability of picking a + sample with ``target=0`` is :math:`9 x 1/36 + 3 x 1/12 = 0.5`. + 2. The probability of picking a sample with ``target=0`` from Dataset 2 is + 3 times higher than those from Dataset 1. As there are 3 times fewer + samples in Dataset 2 with ``target=0``, this makes choosing samples from + Dataset 1 proportionally less likely. + 3. The probability of picking a sample with ``target=1`` from Dataset 2 is + 3 times lower than those from Dataset 1. As there are 3 times fewer + samples in Dataset 1 with ``target=1``, this makes choosing samples from + Dataset 2 proportionally less likely. + + This function assumes targets are stored in a dictionary entry named + ``target`` inside the metadata information for the + :py:data:`.typing.Sample`, and that its value is an integer. + + We then instantiate a pytorch sampler using the inverse probabilities (the + more samples in a class, the less likely it becomes to be sampled). + + Parameters + ---------- + dataset + An instance of torch Dataset. + :py:class:`torch.utils.data.ConcatDataset` instances are supported. + target + The name of a metadata key pointing to an integer property that allows + balancing the dataset. + + Returns + ------- + A sampler, to be used in a dataloader equipped with the same dataset + used to calculate the relative sample weights. + + Raises + ------ + RuntimeError + If requested to balance a dataset (single, not-concatenated) without an + existing target.
+ """ + + def _calculate_weights(targets: list[int]) -> list[float]: + counts = collections.Counter(targets) + weights = {k: 1.0 / v for k, v in counts.items()} + return [weights[k] for k in targets] + + if isinstance(dataset, torch.utils.data.ConcatDataset): + # There are two possible cases: targets/no-targets + metadata_example = dataset.datasets[0][0][1] + if target in metadata_example and isinstance( + metadata_example[target], + int, + ): + # there are integer targets, let's balance with those + logger.info( + f"Balancing sample selection probabilities **and** " + f"concatenated-datasets using metadata targets `{target}`", + ) + targets = [ + k + for ds in dataset.datasets + for k in typing.cast(Dataset, ds).targets() + ] + weights = _calculate_weights(targets) # type: ignore + else: + logger.warning( + f"Balancing samples **and** concatenated-datasets " + f"by using dataset totals as `{target}: int` is not true", + ) + weights = [ + k + for ds in dataset.datasets + for k in len(typing.cast(typing.Sized, ds)) + * [1.0 / len(typing.cast(typing.Sized, ds))] + ] + + pass + + else: + metadata_example = dataset[0][1] + if target in metadata_example and isinstance( + metadata_example[target], + int, + ): + logger.info( + f"Balancing samples from dataset using metadata " + f"targets `{target}`", + ) + weights = _calculate_weights(dataset.targets()) # type: ignore + else: + raise RuntimeError( + f"Cannot balance samples with multiple class targets " + f"({target}: list[int]) or without metadata targets `{target}`", + ) + + return torch.utils.data.WeightedRandomSampler( + weights, + len(weights), + replacement=True, + ) + + class ConcatDataModule(lightning.LightningDataModule): """A conveninent DataModule with dictionary split loading, mini- batching, parallelisation and caching, all in one. @@ -800,7 +928,7 @@ class CachingDataModule(ConcatDataModule): not influence any early stop criteria during training, and are just monitored beyond the ``validation`` dataset. raw_data_loader - An object instance that can load samples and labels from storage. + An object instance that can load samples from storage. **kwargs List of named parameters matching those of :py:class:`ConcatDataModule`, other than ``splits``. diff --git a/src/mednet/libs/common/data/split.py b/src/mednet/libs/common/data/split.py index 64c23fed4a8811ed788d2f4edb0ecab1c43b308d..14d0ac084533d0e79d327e23ec66d4a4b863db8c 100644 --- a/src/mednet/libs/common/data/split.py +++ b/src/mednet/libs/common/data/split.py @@ -121,7 +121,7 @@ def check_database_split_loading( name of a dataset in the split. Each value is a (potentially complex) object that represents a single sample. loader - A loader object that knows how to handle full-samples or just labels. + A loader object that knows how to handle full-samples. limit Maximum number of samples to check (in each split/dataset combination) in this dataset. If set to zero, then check diff --git a/src/mednet/libs/common/data/typing.py b/src/mednet/libs/common/data/typing.py index 48cce9e780dd1be756d5f7b85947feda2b7dc0be..521c8026f57dcc797bae31c239a3357a123f835c 100644 --- a/src/mednet/libs/common/data/typing.py +++ b/src/mednet/libs/common/data/typing.py @@ -17,12 +17,12 @@ First parameter Second parameter A dictionary containing a named set of meta-data. One the most common is - the ``label`` entry. + the ``target`` entry. 
""" class RawDataLoader: - """A loader object can load samples and labels from storage.""" + """A loader object can load samples from storage.""" def sample(self, _: typing.Any) -> Sample: """Load whole samples from media. @@ -77,9 +77,9 @@ class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized): provide a dunder len method. """ - def labels(self) -> list[int | list[int]]: - """Return the integer labels for all samples in the dataset.""" - raise NotImplementedError("You must implement the `labels()` method") + def targets(self) -> list[int | list[int]]: + """Return the integer targets for all samples in the dataset.""" + raise NotImplementedError("You must implement the `targets()` method") DataLoader: typing.TypeAlias = torch.utils.data.DataLoader[Sample] diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py index 2ef7d1cbd97dd0361ad7d00221b016c3e4d77ab1..893cbbb8682fa248cc5eaf642ed404ad25d804d6 100644 --- a/src/mednet/libs/segmentation/config/data/drive/datamodule.py +++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py @@ -42,9 +42,8 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader): Parameters ---------- sample - A tuple containing the path suffix, within the dataset root folder, - where to find the image to be loaded, and an integer, representing the - sample label. + A tuple containing path suffixes to the sample image, target, and mask + to be loaded, within the dataset root folder. Returns ------- @@ -55,7 +54,7 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader): mode="RGB" ) tensor = tv_tensors.Image(to_tensor(image)) - label = tv_tensors.Image( + target = tv_tensors.Image( to_tensor( PIL.Image.open(Path(self.datadir) / str(sample[1])).convert( mode="1", dither=None @@ -70,7 +69,7 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader): ) ) - return tensor, dict(label=label, mask=mask, name=sample[0]) # type: ignore[arg-type] + return tensor, dict(target=target, mask=mask, name=sample[0]) # type: ignore[arg-type] def make_split(basename: str) -> DatabaseSplit: