From 0827ccca3d3af87bf3e6daecaecdde7bf0164096 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Sun, 9 Jul 2023 22:45:00 +0200 Subject: [PATCH] [ptbench.data.datamodule] Implemented typing, added more logging, implemented sampler-balancing depending on sample class prevalence, parallelized dataset caching --- src/ptbench/data/datamodule.py | 424 +++++++++++------- src/ptbench/data/dataset.py | 66 --- .../{raw_data_loader.py => image_utils.py} | 20 +- src/ptbench/data/shenzhen/default.py | 38 +- src/ptbench/data/shenzhen/loader.py | 105 +++++ src/ptbench/data/shenzhen/raw_data_loader.py | 67 --- src/ptbench/data/split.py | 28 +- src/ptbench/data/typing.py | 75 ++++ 8 files changed, 501 insertions(+), 322 deletions(-) delete mode 100644 src/ptbench/data/dataset.py rename src/ptbench/data/{raw_data_loader.py => image_utils.py} (87%) create mode 100644 src/ptbench/data/shenzhen/loader.py delete mode 100644 src/ptbench/data/shenzhen/raw_data_loader.py create mode 100644 src/ptbench/data/typing.py diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index fc9883d2..a7f0a0fe 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -12,8 +12,16 @@ import lightning import torch import torch.utils.data import torchvision.transforms +import tqdm -from tqdm import tqdm +from .typing import ( + DatabaseSplit, + DataLoader, + Dataset, + RawDataLoader, + Sample, + Transform, +) logger = logging.getLogger(__name__) @@ -64,7 +72,7 @@ def _setup_dataloader_multiproc_parameters( return retval -class _DelayedLoadingDataset(torch.utils.data.Dataset): +class _DelayedLoadingDataset(Dataset): """A list that loads its samples on demand. This list mimics a pytorch Dataset, except raw data loading is done @@ -78,11 +86,8 @@ class _DelayedLoadingDataset(torch.utils.data.Dataset): An iterable containing the raw dataset samples loaded from the database splits. - raw_data_loader - A callable that will take the representation of the samples declared - inside database splits, load the actual raw data (if one exists), and - transform it into a :py:class:`torch.Tensor`containing floats in the - interval ``[0, 1]``, and a dictionary with further metadata attributes. + loader + An object instance that can load samples and labels from storage. 
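As an illustration of what such a loader looks like in practice, here is a minimal sketch of a custom ``RawDataLoader`` (the protocol is added by this patch in ``src/ptbench/data/typing.py``). The directory layout and the ``(relative-path, integer-label)`` split entries are hypothetical; only ``sample()`` and ``label()`` are required:

.. code:: python

   import os

   import torchvision.transforms

   from ptbench.data.image_utils import load_pil_baw
   from ptbench.data.typing import RawDataLoader, Sample


   class MyRawDataLoader(RawDataLoader):
       """Hypothetical loader: split entries are (path, label) tuples."""

       def __init__(self, datadir: str):
           self.datadir = datadir
           self.to_tensor = torchvision.transforms.ToTensor()

       def sample(self, item: tuple[str, int]) -> Sample:
           # the image is only read from disk when the sample is requested,
           # which is what makes delayed (non-cached) loading cheap to set up
           image = load_pil_baw(os.path.join(self.datadir, item[0]))
           return self.to_tensor(image), dict(label=item[1])

       def label(self, item: tuple[str, int]) -> int:
           # labels are available without decoding the image file
           return item[1]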
transforms A set of transforms that should be applied on-the-fly for this dataset, @@ -92,30 +97,30 @@ class _DelayedLoadingDataset(torch.utils.data.Dataset): def __init__( self, split: typing.Sequence[typing.Any], - raw_data_loader: typing.Callable[ - [typing.Any], tuple[torch.Tensor, typing.Mapping] - ], - transforms: typing.Sequence[ - typing.Callable[[torch.Tensor], torch.Tensor] - ] = [], + loader: RawDataLoader, + transforms: typing.Sequence[Transform] = [], ): self.split = split - self.raw_data_loader = raw_data_loader - # Cannot unpack empty list - if len(transforms) > 0: - self.transform = torchvision.transforms.Compose([*transforms]) - else: - self.transform = torchvision.transforms.Compose([]) + self.loader = loader + self.transform = torchvision.transforms.Compose(transforms) + + def labels(self) -> list[int]: + """Returns the integer labels for all samples in the dataset.""" + return [self.loader.label(k) for k in self.split] - def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]: - tensor, metadata = self.raw_data_loader(self.split[key]) + def __getitem__(self, key: int) -> Sample: + tensor, metadata = self.loader.sample(self.split[key]) return self.transform(tensor), metadata def __len__(self): return len(self.split) + def __iter__(self): + for x in range(len(self)): + yield self[x] -class _CachedDataset(torch.utils.data.Dataset): + +class _CachedDataset(Dataset): """Basically, a list of preloaded samples. This dataset will load all samples from the split during construction @@ -130,11 +135,14 @@ class _CachedDataset(torch.utils.data.Dataset): An iterable containing the raw dataset samples loaded from the database splits. - raw_data_loader - A callable that will take the representation of the samples declared - inside database splits, load the actual raw data (if one exists), and - transform it into a :py:class:`torch.Tensor`containing floats in the - interval ``[0, 1]``, and a dictionary with further metadata attributes. + loader + An object instance that can load samples and labels from storage. + + parallel + Use multiprocessing for data loading: if set to -1 (default), disables + multiprocessing data loading. Set to 0 to enable as many data loading + instances as processing cores as available in the system. Set to >= 1 + to enable that many multiprocessing instances for data loading. 
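The caching strategy selected through ``parallel`` boils down to the following sketch (a simplification of the implementation below; note that the loader object must be picklable when multiprocessing is enabled, as it is shipped to the worker processes):

.. code:: python

   import multiprocessing

   import tqdm


   def cache_samples(loader, split, parallel: int = -1) -> list:
       """Sketch of the sample pre-loading strategy described above."""
       if parallel < 0:
           # sequential loading, in the current process
           return [loader.sample(k) for k in tqdm.tqdm(split, unit="sample")]
       # 0 means one process per CPU core, >= 1 means exactly that many processes
       instances = parallel or multiprocessing.cpu_count()
       with multiprocessing.Pool(instances) as p:
           return list(tqdm.tqdm(p.imap(loader.sample, split), total=len(split)))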
transforms A set of transforms that should be applied to the cached samples for @@ -145,43 +153,96 @@ class _CachedDataset(torch.utils.data.Dataset): def __init__( self, split: typing.Sequence[typing.Any], - raw_data_loader: typing.Callable[ - [typing.Any], tuple[torch.Tensor, typing.Mapping] - ], - transforms: typing.Sequence[ - typing.Callable[[torch.Tensor], torch.Tensor] - ] = [], + loader: RawDataLoader, + parallel: int = -1, + transforms: typing.Sequence[Transform] = [], ): - # Cannot unpack empty list - if len(transforms) > 0: - self.transform = torchvision.transforms.Compose([*transforms]) - else: - self.transform = torchvision.transforms.Compose([]) - - self.data = [raw_data_loader(k) for k in tqdm(split)] + self.transform = torchvision.transforms.Compose(transforms) - def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]: + if parallel < 0: + self.data = [ + loader.sample(k) for k in tqdm.tqdm(split, unit="sample") + ] + else: + instances = parallel or multiprocessing.cpu_count() + logger.info(f"Caching dataset using {instances} processes...") + with multiprocessing.Pool(instances) as p: + self.data = list( + tqdm.tqdm(p.imap(loader.sample, split), total=len(split)) + ) + + def labels(self) -> list[int]: + """Returns the integer labels for all samples in the dataset.""" + return [k[1]["label"] for k in self.data] + + def __getitem__(self, key: int) -> Sample: tensor, metadata = self.data[key] return self.transform(tensor), metadata def __len__(self): return len(self.data) + def __iter__(self): + for x in range(len(self)): + yield self[x] + -def get_sample_weights( - dataset: _DelayedLoadingDataset | _CachedDataset, -) -> torch.Tensor: - """Computes the (inverse) probabilities of samples based on their class. +def _make_balanced_random_sampler( + dataset: Dataset, + target: str = "label", +) -> torch.utils.data.WeightedRandomSampler: + """Generates a pytorch sampler that samples according to class + probabilities. This function takes as input a torch Dataset, and computes the weights to balance each class in the dataset, and the datasets themselves if one passes a :py:class:`torch.utils.data.ConcatDataset`. - This function assumes labels are stored on the second entry of sample and - are integers. If that is not the case, we just balance the overall dataset - w.r.t. each other (e.g. with a :py:class:`torch.utils.data.ConcatDataset`). - For single dataset (not a concatenated one) without labels this function is - the same as a no-op. + In this implementation, we balance **both** class and dataset-origin + probabilities, what you expect for a truly *equitable* random sampler. + + Take this example for illustration: + + * Dataset 1: N = 10 samples, 9 samples with target=0, 1 sample with target=1 + * Dataset 2: N = 6 samples, 3 samples with target=0, 3 samples with target=1 + + So: + + | Dataset | Target | Samples | Weight | Normalised weight | + +---------+--------+---------+--------+-------------------+ + | 1 | 0 | 9 | 1/9 | 1/36 | + | 1 | 1 | 1 | 1/1 | 1/4 | + | 2 | 0 | 3 | 1/3 | 1/12 | + | 2 | 1 | 3 | 1/3 | 1/12 | + + Legend: + + * Weight: the weights computed by this method + * Normalised weight: the weight per sample used by the random sampler, + after normalising the weights by the sum of all weights in the + concatenated dataset, such that the sum of all normalized weights times + the number of samples is 1. + + The properties of this algorithm are as follows: + + 1. The probability of picking a sample from any target is the same (0.5 in + this case). 
To verify this, notice that the probability of picking a
+       sample with ``target=0`` is :math:`9 \times 1/36 + 3 \times 1/12 = 0.5`.
+    2. The probability of picking a sample with ``target=0`` from Dataset 2 is
+       3 times higher than those from Dataset 1. As there are 3 times fewer
+       samples in Dataset 2 with ``target=0``, this makes choosing samples from
+       Dataset 1 proportionally less likely.
+    3. The probability of picking a sample with ``target=1`` from Dataset 2 is
+       3 times lower than those from Dataset 1. As there are 3 times fewer
+       samples in Dataset 1 with ``target=1``, this makes choosing samples from
+       Dataset 2 proportionally less likely.
+
+    This function assumes targets are stored in a metadata dictionary entry
+    whose name is given by ``target``, inside each :py:data:`Sample`, and that
+    its value is an integer.
+
+    We then instantiate a pytorch sampler using the inverse probabilities (the
+    more samples there are of a class, the less likely it becomes to be
+    sampled).


    Parameters
    ----------

    dataset
        An instance of torch Dataset.
-        :py:class:`torch.utils.data.ConcatDataset` are supported
+        :py:class:`torch.utils.data.ConcatDataset` are supported.
+
+    target
+        The name of a metadata key pointing to an integer property that allows
+        balancing the dataset.


    Returns
    -------

-    sample_weights
-        The weights for all the samples in the dataset given as input
+    sampler
+        A sampler, to be used in a dataloader equipped with the same dataset
+        used to calculate the relative sample weights.
+
+
+    Raises
+    ------
+
+    RuntimeError
+        If requested to balance a dataset (single, not-concatenated) without an
+        existing target.
    """

-    retval = []
-
-    def _calculate_dataset_weights(
-        d: _DelayedLoadingDataset | _CachedDataset,
-    ) -> torch.Tensor:
-        """Calculates the weights for one dataset."""
-        if "label" in d[0][1] and isinstance(d[0][1]["label"], int):
-            # there are labels
-            targets = [k[1]["label"] for k in d]
-            counts = collections.Counter(targets)
-            weights = {k: 1.0 / v for k, v in counts.items()}
-            weight_per_sample = [weights[k[1]] for k in d]
-            return torch.tensor(weight_per_sample)
-        else:
-            # no labels, weight dataset samples only by count
-            # n.b.: only works if __len__(d) is implemented, we turn off typechecking
-            weight = 1.0 / len(d)
-            return torch.tensor(len(d) * [weight])
+
+    def _calculate_weights(targets: list[int]) -> list[float]:
+        counts = collections.Counter(targets)
+        weights = {k: 1.0 / v for k, v in counts.items()}
+        return [weights[k] for k in targets]

     if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            retval.append(_calculate_dataset_weights(ds))  # type: ignore[arg-type]
+        # There are two possible cases: targets/no-targets
+        metadata_example = dataset.datasets[0][0][1]
+        if target in metadata_example and isinstance(
+            metadata_example[target], int
+        ):
+            # there are integer targets, let's balance with those
+            logger.info(
+                f"Balancing sample selection probabilities **and** "
+                f"concatenated-datasets using metadata targets `{target}`"
+            )
+            targets = [
+                k
+                for ds in dataset.datasets
+                for k in typing.cast(Dataset, ds).labels()
+            ]
+            weights = _calculate_weights(targets)
+        else:
+            logger.warning(
+                f"Balancing samples **and** concatenated-datasets "
+                f"WITHOUT metadata targets (`{target}` not available)"
+            )
+            weights = [
+                k
+                for ds in dataset.datasets
+                for k in len(typing.cast(typing.Sized, ds))
+                * [1.0 / len(typing.cast(typing.Sized, ds))]
+            ]

-    # Concatenate sample weights from all the datasets
-    return 
torch.cat(retval) + pass - return _calculate_dataset_weights(dataset) + else: + metadata_example = dataset[0][1] + if target in metadata_example and isinstance( + metadata_example[target], int + ): + logger.info( + f"Balancing samples from dataset using metadata " + f"targets `{target}`" + ) + weights = _calculate_weights(dataset.labels()) + else: + raise RuntimeError( + f"Cannot balance samples without metadata targets `{target}`" + ) + + return torch.utils.data.WeightedRandomSampler( + weights, len(weights), replacement=True + ) class CachingDataModule(lightning.LightningDataModule): @@ -253,10 +355,10 @@ class CachingDataModule(lightning.LightningDataModule): database_split A dictionary that contains string keys representing subset names, and values that are iterables over sample representations (potentially on - disk). These objects are passed to the ``raw_data_loader`` for loading + disk). These objects are passed to the ``sample_loader`` for loading the sample data (and metadata) in memory. The objects represented may be of any format (e.g. list, dictionary, etc), for as long as the - ``raw_data_loader`` can properly handle it. To check the split and the + ``sample_loader`` can properly handle it. To check the split and the loader function works correctly, you may use :py:func:`..dataset.check_database_split_loading`. As is, this class expects at least one entry called ``train`` to exist in the input @@ -265,12 +367,8 @@ class CachingDataModule(lightning.LightningDataModule): influence any early stop criteria during training, and are just monitored beyond the ``validation`` dataset. - raw_data_loader - A callable that will take the representation of the samples declared - inside database splits, load the actual raw data (if one exists), and - transform it into a :py:class:`torch.Tensor`containing floats in the - interval ``[0, 1]``, and a dictionary with further metadata attributes. - Samples can be cached **after** raw data loading. + loader + An object instance that can load samples and labels from storage. cache_samples If set, then issue raw data loading during ``prepare_data()``, and @@ -280,25 +378,10 @@ class CachingDataModule(lightning.LightningDataModule): this attribute to ``True``. It is typicall useful for relatively small datasets. - train_sampler - If set, then reset the default random sampler from torch with a - (potentially) biased one of your choice. Notice this will not play an - effect on the batching strategy, only on the way the samples are picked - from the original dataset. A typical replacement is: - - .. code:: python - - sample_weights = get_sample_weights(dataset) - train_sampler = torch.utils.data.WeightedRandomSampler( - sample_weights, len(sample_weights), replacement=True - ) - - The variable ``sample_weights`` represents the probabilities (may not - sum to one) of picking each particular sample in the dataset for a - batch. We typically set ``replacement=True`` to avoid issues with - missing data from one of the classes in mini-batches. This is - particularly important in highly unbalanced datasets. The function - :py:func:`get_sample_weights` may help you in this aspect. + balance_sampler_by_class + If set, then modifies the random sampler used during training and + validation to balance sample picking probability, making sample + across classes **and** datasets equitable. 
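The next snippet verifies, by hand, the weighting scheme documented for ``_make_balanced_random_sampler()`` above, using the two hypothetical datasets from its docstring (9/1 and 3/3 samples per target). The variable names are made up for illustration; only ``collections.Counter`` and :py:class:`torch.utils.data.WeightedRandomSampler` are real APIs:

.. code:: python

   import collections

   import torch

   # targets of the two hypothetical datasets from the docstring example
   dataset_1 = [0] * 9 + [1] * 1
   dataset_2 = [0] * 3 + [1] * 3


   def inverse_frequency_weights(targets):
       """Weight of each sample: the inverse of its class count."""
       counts = collections.Counter(targets)
       return [1.0 / counts[t] for t in targets]


   # per-dataset weights (1/9, 1/1, 1/3 and 1/3), then concatenated
   weights = inverse_frequency_weights(dataset_1) + inverse_frequency_weights(dataset_2)
   targets = dataset_1 + dataset_2

   # probability of drawing target=0: (9 * 1/9 + 3 * 1/3) / sum(weights) == 0.5
   p_target_0 = sum(w for w, t in zip(weights, targets) if t == 0) / sum(weights)
   assert abs(p_target_0 - 0.5) < 1e-9

   sampler = torch.utils.data.WeightedRandomSampler(
       weights, len(weights), replacement=True
   )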
model_transforms A list of transforms (torch modules) that will be applied after @@ -346,17 +429,15 @@ class CachingDataModule(lightning.LightningDataModule): to enable that many multiprocessing instances for data loading. """ + DatasetDictionary = dict[str, Dataset] + def __init__( self, - database_split: dict[str, typing.Sequence[typing.Any]], - raw_data_loader: typing.Callable[ - [typing.Any], tuple[torch.Tensor, typing.Mapping] - ], + database_split: DatabaseSplit, + raw_data_loader: RawDataLoader, cache_samples: bool = False, - train_sampler: typing.Optional[torch.utils.data.Sampler] = None, - model_transforms: list[ - typing.Callable[[torch.Tensor], torch.Tensor] - ] = [], + balance_sampler_by_class: bool = False, + model_transforms: list[Transform] = [], batch_size: int = 1, batch_chunk_count: int = 1, drop_incomplete_batch: bool = False, @@ -369,7 +450,8 @@ class CachingDataModule(lightning.LightningDataModule): self.database_split = database_split self.raw_data_loader = raw_data_loader self.cache_samples = cache_samples - self.train_sampler = train_sampler + self._train_sampler = None + self.balance_sampler_by_class = balance_sampler_by_class self.model_transforms = model_transforms self.drop_incomplete_batch = drop_incomplete_batch @@ -380,11 +462,18 @@ class CachingDataModule(lightning.LightningDataModule): ) # should only be true if GPU available and using it # datasets that have been setup() for the current stage - self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef] + self._datasets: CachingDataModule.DatasetDictionary = {} @property def parallel(self) -> int: - """The parallel property.""" + """Whether to use multiprocessing for data loading. + + Use multiprocessing for data loading: if set to -1 (default), + disables multiprocessing data loading. Set to 0 to enable as + many data loading instances as processing cores as available in + the system. Set to >= 1 to enable that many multiprocessing + instances for data loading. + """ return self._parallel @parallel.setter @@ -394,7 +483,28 @@ class CachingDataModule(lightning.LightningDataModule): value ) # datasets that have been setup() for the current stage - self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef] + self._datasets = {} + + @property + def balance_sampler_by_class(self): + """Whether to balance samples across labels/datasets. + + If set, then modifies the random sampler used during training + and validation to balance sample picking probability, making + sample across classes **and** datasets equitable. + """ + return self._train_sampler is not None + + @balance_sampler_by_class.setter + def balance_sampler_by_class(self, value: bool): + if value: + if "train" not in self._datasets: + self._setup_dataset("train") + self._train_sampler = _make_balanced_random_sampler( + self._datasets["train"] + ) + else: + self._train_sampler = None def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None: """Coherently sets the batch-chunk-size after validation. @@ -450,22 +560,39 @@ class CachingDataModule(lightning.LightningDataModule): """ if name in self._datasets: - logger.info(f"Dataset {name} is already setup. Not reloading it.") + logger.info( + f"Dataset `{name}` is already setup. " + f"Not re-instantiating it." + ) return if self.cache_samples: - logger.info(f"Caching {name} dataset") + logger.info( + f"Loading dataset:`{name}` into memory (caching)." 
+ f" Trade-off: CPU RAM: more | Disk: less" + ) self._datasets[name] = _CachedDataset( self.database_split[name], self.raw_data_loader, + self.parallel, self.model_transforms, ) else: + logger.info( + f"Loading dataset:`{name}` without caching." + f" Trade-off: CPU RAM: less | Disk: more" + ) self._datasets[name] = _DelayedLoadingDataset( self.database_split[name], self.raw_data_loader, self.model_transforms, ) + def _val_dataset_keys(self) -> list[str]: + """Returns list of validation dataset names.""" + return ["validation"] + [ + k for k in self.database_split.keys() if k.startswith("monitor-") + ] + def setup(self, stage: str) -> None: """Sets up datasets for different tasks on the pipeline. @@ -493,17 +620,12 @@ class CachingDataModule(lightning.LightningDataModule): """ if stage == "fit": - self._setup_dataset("train") - self._setup_dataset("validation") - for k in self.database_split: - if k.startswith("monitor-"): - self._setup_dataset(k) + for k in ["train"] + self._val_dataset_keys(): + self._setup_dataset(k) elif stage == "validate": - self._setup_dataset("validation") - for k in self.database_split: - if k.startswith("monitor-"): - self._setup_dataset(k) + for k in self._val_dataset_keys(): + self._setup_dataset(k) elif stage == "test": self._setup_dataset("test") @@ -535,32 +657,33 @@ class CachingDataModule(lightning.LightningDataModule): * ``predict``: uses only the test dataset """ - self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef] + self._datasets = {} - def train_dataloader(self) -> torch.utils.data.DataLoader: + def train_dataloader(self) -> DataLoader: """Returns the train data loader.""" return torch.utils.data.DataLoader( self._datasets["train"], - shuffle=True, + shuffle=(self._train_sampler is None), batch_size=self._chunk_size, drop_last=self.drop_incomplete_batch, pin_memory=self.pin_memory, - sampler=self.train_sampler, + sampler=self._train_sampler, **self._dataloader_multiproc, ) - def available_dataset_keys(self) -> typing.KeysView[str]: - """Returns all names for datasets that are setup.""" - return self._datasets.keys() + def unshuffled_train_dataloader(self) -> DataLoader: + """Returns the train data loader without shuffling.""" - def val_database_split_keys(self) -> list[str]: - """Returns list of validation dataset names.""" - return ["validation"] + [ - k for k in self.database_split.keys() if k.startswith("monitor-") - ] + return torch.utils.data.DataLoader( + self._datasets["train"], + shuffle=False, + batch_size=self._chunk_size, + drop_last=False, + **self._dataloader_multiproc, + ) - def val_dataloader(self) -> dict[str, torch.utils.data.DataLoader]: + def val_dataloader(self) -> dict[str, DataLoader]: """Returns the validation data loader(s)""" validation_loader_opts = { @@ -571,27 +694,28 @@ class CachingDataModule(lightning.LightningDataModule): } validation_loader_opts.update(self._dataloader_multiproc) - # select all keys of interest return { k: torch.utils.data.DataLoader( self._datasets[k], **validation_loader_opts ) - for k in self.val_database_split_keys() + for k in self._val_dataset_keys() } - def test_dataloader(self): + def test_dataloader(self) -> dict[str, DataLoader]: """Returns the test data loader(s)""" - return torch.utils.data.DataLoader( - self._datasets["test"], - batch_size=self._chunk_size, - shuffle=False, - drop_last=self.drop_incomplete_batch, - pin_memory=self.pin_memory, - **self._dataloader_multiproc, + return dict( + test=torch.utils.data.DataLoader( + 
self._datasets["test"], + batch_size=self._chunk_size, + shuffle=False, + drop_last=self.drop_incomplete_batch, + pin_memory=self.pin_memory, + **self._dataloader_multiproc, + ) ) - def predict_dataloader(self): + def predict_dataloader(self) -> dict[str, DataLoader]: """Returns the prediction data loader(s)""" return self.test_dataloader() diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py deleted file mode 100644 index 3529ee7f..00000000 --- a/src/ptbench/data/dataset.py +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import logging - -import torch -import torch.utils.data - -logger = logging.getLogger(__name__) - - -def _get_positive_weights(dataloader): - """Compute the positive weights of each class of the dataset to balance the - BCEWithLogitsLoss criterion. - - This function takes as input a :py:class:`torch.utils.data.DataLoader` - and computes the positive weights of each class to use them to have - a balanced loss. - - - Parameters - ---------- - - dataloader : :py:class:`torch.utils.data.DataLoader` - A DataLoader from which to compute the positive weights. Must contain a 'label' key in the metadata returned by __getitem__(). - - - Returns - ------- - - positive_weights : :py:class:`torch.Tensor` - the positive weight of each class in the dataset given as input - """ - targets = [] - - for batch in dataloader: - targets.extend(batch[1]["label"]) - - targets = torch.tensor(targets) - - # Binary labels - if len(list(targets.shape)) == 1: - class_sample_count = [ - float((targets == t).sum().item()) - for t in torch.unique(targets, sorted=True) - ] - - # Divide negatives by positives - positive_weights = torch.tensor( - [class_sample_count[0] / class_sample_count[1]] - ).reshape(-1) - - # Multiclass labels - else: - class_sample_count = torch.sum(targets, dim=0) - negative_class_sample_count = ( - torch.full((targets.size()[1],), float(targets.size()[0])) - - class_sample_count - ) - - positive_weights = negative_class_sample_count / ( - class_sample_count + negative_class_sample_count - ) - - return positive_weights diff --git a/src/ptbench/data/raw_data_loader.py b/src/ptbench/data/image_utils.py similarity index 87% rename from src/ptbench/data/raw_data_loader.py rename to src/ptbench/data/image_utils.py index d852743f..ac31b9ce 100644 --- a/src/ptbench/data/raw_data_loader.py +++ b/src/ptbench/data/image_utils.py @@ -5,6 +5,8 @@ """Data loading code.""" +import pathlib + import numpy import PIL.Image @@ -41,58 +43,58 @@ class RemoveBlackBorders: return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) -def load_pil(path): +def load_pil(path: str | pathlib.Path) -> PIL.Image.Image: """Loads a sample data. Parameters ---------- - path : str + path The full path leading to the image to be loaded Returns ------- - image : PIL.Image.Image + image A PIL image """ return PIL.Image.open(path) -def load_pil_baw(path): +def load_pil_baw(path: str | pathlib.Path) -> PIL.Image.Image: """Loads a sample data. Parameters ---------- - path : str + path The full path leading to the image to be loaded Returns ------- - image : PIL.Image.Image + image A PIL image in grayscale mode """ return load_pil(path).convert("L") -def load_pil_rgb(path): +def load_pil_rgb(path: str | pathlib.Path) -> PIL.Image.Image: """Loads a sample data. 
Parameters ---------- - path : str + path The full path leading to the image to be loaded Returns ------- - image : PIL.Image.Image + image A PIL image in RGB mode """ return load_pil(path).convert("RGB") diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py index 793c3d41..bfe93f44 100644 --- a/src/ptbench/data/shenzhen/default.py +++ b/src/ptbench/data/shenzhen/default.py @@ -4,19 +4,38 @@ """Shenzhen datamodule for computer-aided diagnosis (default protocol) -See :py:mod:`ptbench.data.shenzhen` for dataset details. +See :py:mod:`ptbench.data.shenzhen` for more database details. This configuration: -* raw data (default): :py:obj:`ptbench.data.shenzhen._tranforms` -* augmentations: elastic deformation (probability = 80%) -* output image resolution: 512x512 pixels + +* Raw data input (on disk): + + * PNG images (black and white, encoded as color images) + * Variable width and height: + + * widths: from 1130 to 3001 pixels + * heights: from 948 to 3001 pixels + +* Output image: + + * Transforms: + + * Load raw PNG with :py:mod:`PIL` + * Remove black borders + * Torch resizing(512px, 512px) + * Torch center cropping (512px, 512px) + + * Final specifications: + + * Fixed resolution: 512x512 pixels + * Color RGB encoding """ import importlib.resources from ..datamodule import CachingDataModule from ..split import JSONDatabaseSplit -from .raw_data_loader import raw_data_loader +from .loader import RawDataLoader datamodule = CachingDataModule( database_split=JSONDatabaseSplit( @@ -24,12 +43,5 @@ datamodule = CachingDataModule( "default.json.bz2" ) ), - raw_data_loader=raw_data_loader, - cache_samples=False, - # train_sampler: typing.Optional[torch.utils.data.Sampler] = None, - # model_transforms = [], - # batch_size = 1, - # batch_chunk_count = 1, - # drop_incomplete_batch = False, - # parallel = -1, + raw_data_loader=RawDataLoader(), ) diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/loader.py new file mode 100644 index 00000000..e983fe70 --- /dev/null +++ b/src/ptbench/data/shenzhen/loader.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Shenzhen dataset for computer-aided diagnosis. + +The standard digital image database for Tuberculosis is created by the National +Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s +Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from +out-patient clinics, and were captured as part of the daily routine using +Philips DR Digital Diagnose systems. + +* Reference: [MONTGOMERY-SHENZHEN-2014]_ +* Original resolution (height x width or width x height): 3000 x 3000 or less +* Split reference: none +* Protocol ``default``: + + * Training samples: 64% of TB and healthy CXR (including labels) + * Validation samples: 16% of TB and healthy CXR (including labels) + * Test samples: 20% of TB and healthy CXR (including labels) +""" + +import os + +import torchvision.transforms + +from ...utils.rc import load_rc +from ..image_utils import RemoveBlackBorders, load_pil_baw +from ..typing import RawDataLoader as _BaseRawDataLoader +from ..typing import Sample + + +class RawDataLoader(_BaseRawDataLoader): + """A specialized raw-data-loader for the Shenzen dataset. + + Attributes + ---------- + + datadir + This variable contains the base directory where the database raw data + is stored. 
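Assuming the Shenzhen raw data is installed under the directory configured through the ``datadir.shenzhen`` key, using this loader on a split entry yields a 512x512 grayscale tensor plus a metadata dictionary carrying the label. The file name below is only a placeholder for whatever the JSON split actually contains:

.. code:: python

   import torch

   from ptbench.data.shenzhen.loader import RawDataLoader

   loader = RawDataLoader()  # reads ``datadir.shenzhen`` from the configuration

   # hypothetical split entry: (path relative to datadir, integer label)
   entry = ("CXR_png/CHNCXR_0001_0.png", 0)

   tensor, metadata = loader.sample(entry)
   assert isinstance(tensor, torch.Tensor)
   assert tensor.shape == (1, 512, 512)  # grayscale, resized and center-cropped
   assert metadata["label"] == 0

   # labels can be read without decoding the image at all
   assert loader.label(entry) == 0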
+ + transform + Transforms that are always applied to the loaded raw images. + """ + + datadir: str + transform: torchvision.transforms.Compose + + def __init__(self): + self.datadir = load_rc().get( + "datadir.shenzhen", os.path.realpath(os.curdir) + ) + + self.transform = torchvision.transforms.Compose( + [ + RemoveBlackBorders(), + torchvision.transforms.Resize(512), + torchvision.transforms.CenterCrop(512), + torchvision.transforms.ToTensor(), + ] + ) + + def sample(self, sample: tuple[str, int]) -> Sample: + """Loads a single image sample from the disk. + + Parameters + ---------- + + sample: + A tuple containing the path suffix, within the dataset root folder, + where to find the image to be loaded, and an integer, representing the + sample label. + + + Returns + ------- + + sample + The sample representation + """ + tensor = self.transform( + load_pil_baw(os.path.join(self.datadir, sample[0])) + ) + return tensor, dict(label=sample[1]) # type: ignore[arg-type] + + def label(self, sample: tuple[str, int]) -> int: + """Loads a single image sample label from the disk. + + Parameters + ---------- + + sample: + A tuple containing the path suffix, within the dataset root folder, + where to find the image to be loaded, and an integer, representing the + sample label. + + + Returns + ------- + + label + The integer label associated with the sample + """ + return sample[1] diff --git a/src/ptbench/data/shenzhen/raw_data_loader.py b/src/ptbench/data/shenzhen/raw_data_loader.py deleted file mode 100644 index 0c2c39df..00000000 --- a/src/ptbench/data/shenzhen/raw_data_loader.py +++ /dev/null @@ -1,67 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -"""Shenzhen dataset for computer-aided diagnosis. - -The standard digital image database for Tuberculosis is created by the National -Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s -Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from -out-patient clinics, and were captured as part of the daily routine using -Philips DR Digital Diagnose systems. - -* Reference: [MONTGOMERY-SHENZHEN-2014]_ -* Original resolution (height x width or width x height): 3000 x 3000 or less -* Split reference: none -* Protocol ``default``: - - * Training samples: 64% of TB and healthy CXR (including labels) - * Validation samples: 16% of TB and healthy CXR (including labels) - * Test samples: 20% of TB and healthy CXR (including labels) -""" - -import os -import typing - -import torch.nn -import torchvision.transforms - -from ...utils.rc import load_rc -from ..raw_data_loader import RemoveBlackBorders, load_pil_baw - -_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir)) -"""This variable contains the base directory where the database raw data is -stored.""" - -_transform = torchvision.transforms.Compose( - [ - RemoveBlackBorders(), - torchvision.transforms.Resize(512), - torchvision.transforms.CenterCrop(512), - torchvision.transforms.ToTensor(), - ] -) -"""Transforms that are always applied to the loaded raw images.""" - - -def raw_data_loader( - sample: tuple[str, int] -) -> tuple[torch.Tensor, typing.Mapping]: - """Loads a single image sample from the disk. - - Parameters - ---------- - - img_path - The path suffix, within the dataset root folder, where to find the - image to be loaded. 
- - - Returns - ------- - - image - A PIL image in grayscale mode - """ - tensor = _transform(load_pil_baw(os.path.join(_datadir, sample[0]))) - return tensor, dict(label=sample[1]) # type: ignore[arg-type] diff --git a/src/ptbench/data/split.py b/src/ptbench/data/split.py index 06e813ee..78fe7e33 100644 --- a/src/ptbench/data/split.py +++ b/src/ptbench/data/split.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import collections.abc import csv import importlib.abc import json @@ -12,13 +11,12 @@ import typing import torch +from .typing import DatabaseSplit, RawDataLoader + logger = logging.getLogger(__name__) -class JSONDatabaseSplit( - dict, - typing.Mapping[str, typing.Sequence[typing.Any]], -): +class JSONDatabaseSplit(DatabaseSplit): """Defines a loader that understands a database split (train, test, etc) in JSON format. @@ -73,7 +71,7 @@ class JSONDatabaseSplit( self.path = path self.subsets = self._load_split_from_disk() - def _load_split_from_disk(self) -> dict[str, typing.Sequence[typing.Any]]: + def _load_split_from_disk(self) -> DatabaseSplit: """Loads all subsets in a split from its file system representation. This method will load JSON information for the current split and return @@ -109,7 +107,7 @@ class JSONDatabaseSplit( return len(self.subsets) -class CSVDatabaseSplit(collections.abc.Mapping): +class CSVDatabaseSplit(DatabaseSplit): """Defines a loader that understands a database split (train, test, etc) in CSV format. @@ -154,9 +152,7 @@ class CSVDatabaseSplit(collections.abc.Mapping): self.directory = directory self.subsets = self._load_split_from_disk() - def _load_split_from_disk( - self, - ) -> dict[str, list[typing.Any]]: + def _load_split_from_disk(self) -> DatabaseSplit: """Loads all subsets in a split from its file system representation. This method will load CSV information for the current split and return all @@ -171,7 +167,7 @@ class CSVDatabaseSplit(collections.abc.Mapping): A dictionary mapping subset names to lists of JSON objects """ - retval = {} + retval: DatabaseSplit = {} for subset in self.directory.iterdir(): if str(subset).endswith(".csv.bz2"): logger.debug(f"Loading database split from {subset}...") @@ -204,8 +200,8 @@ class CSVDatabaseSplit(collections.abc.Mapping): def check_database_split_loading( - database_split: typing.Mapping[str, typing.Sequence[typing.Any]], - loader: typing.Callable[[typing.Any], torch.Tensor], + database_split: DatabaseSplit, + loader: RawDataLoader, limit: int = 0, ) -> int: """For each subset in the split, check if all data can be correctly loaded @@ -224,9 +220,7 @@ def check_database_split_loading( object that represents a single sample. loader - A callable that transforms sample entries in the database split - into :py:class:`torch.Tensor` objects that can be used for training - or inference. + A loader object that knows how to handle full-samples or just labels. 
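As a usage sketch, the Shenzhen ``default`` split shipped with the package could be sanity-checked as follows. The resource path mirrors the one used by ``shenzhen/default.py``, and the interpretation of the return value as an error count follows this function's implementation; both are stated here as assumptions:

.. code:: python

   import importlib.resources

   from ptbench.data.shenzhen.loader import RawDataLoader
   from ptbench.data.split import JSONDatabaseSplit, check_database_split_loading

   split = JSONDatabaseSplit(
       importlib.resources.files("ptbench.data.shenzhen").joinpath(
           "default.json.bz2"
       )
   )

   # only try the first 5 samples of each subset (limit=5); a return value of
   # zero means every attempted sample loaded into a torch.Tensor
   problems = check_database_split_loading(split, RawDataLoader(), limit=5)
   print(f"loading problems found: {problems}")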
limit
        Maximum number of samples to check (in each split/subset

@@ -248,7 +242,7 @@ check_database_split_loading(
         samples = subset if not limit else subset[:limit]
         for pos, sample in enumerate(samples):
             try:
-                data = loader(sample)
+                data, _ = loader.sample(sample)
                 assert isinstance(data, torch.Tensor)
             except Exception as e:
                 logger.info(
diff --git a/src/ptbench/data/typing.py b/src/ptbench/data/typing.py
new file mode 100644
index 00000000..344c1294
--- /dev/null
+++ b/src/ptbench/data/typing.py
@@ -0,0 +1,75 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Defines most common types used in code."""
+
+import typing
+
+import torch
+import torch.utils.data
+
+Sample = tuple[torch.Tensor, typing.Mapping[str, typing.Any]]
+"""Definition of a sample.
+
+First parameter
+    The actual data that is input to the model
+
+Second parameter
+    A dictionary containing a named set of meta-data. One of the most common
+    is the ``label`` entry.
+"""
+
+
+class RawDataLoader:
+    """A loader object can load samples and labels from storage."""
+
+    def sample(self, _: typing.Any) -> Sample:
+        """Loads whole samples from media."""
+        raise NotImplementedError("You must implement the `sample()` method")
+
+    def label(self, k: typing.Any) -> int:
+        """Loads only sample label from media.
+
+        If you do not override this implementation, then, by default,
+        this method will call :py:meth:`sample` to load the whole sample
+        and extract the label.
+        """
+        return self.sample(k)[1]["label"]
+
+
+Transform = typing.Callable[[torch.Tensor], torch.Tensor]
+"""A callable that transforms tensors into (other) tensors.
+
+Typically used in data-processing pipelines inside pytorch.
+"""
+
+TransformSequence = typing.Sequence[Transform]
+"""A sequence of transforms."""
+
+DatabaseSplit = dict[str, typing.Sequence[typing.Any]]
+"""The definition of a database split.
+
+A database split maps subset names to sequences of objects that, through a
+RawDataLoader, eventually become Samples in the processing pipeline.
+"""
+
+
+class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized):
+    """Our own definition of a pytorch Dataset, with interesting properties.
+
+    We iterate over Sample objects in this case. Our datasets always
+    provide a dunder len method.
+    """
+
+    def labels(self) -> list[int]:
+        """Returns the integer labels for all samples in the dataset."""
+        raise NotImplementedError("You must implement the `labels()` method")
+
+
+DataLoader = torch.utils.data.DataLoader[Sample]
+"""Our own augmented definition of a pytorch DataLoader.
+
+We iterate over Sample objects in this case.
+"""
-- GitLab
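Taken together, the pieces introduced by this patch compose as sketched below. The ``RandomLoader`` and the in-memory split are invented for illustration; the ``CachingDataModule`` keyword arguments are the ones visible in the diff above:

.. code:: python

   import torch

   from ptbench.data.datamodule import CachingDataModule
   from ptbench.data.typing import RawDataLoader, Sample


   class RandomLoader(RawDataLoader):
       """Hypothetical loader: split entries are (name, label) tuples and the
       sample data is just a random tensor."""

       def sample(self, item: tuple[str, int]) -> Sample:
           return torch.rand(1, 16, 16), dict(label=item[1])

       def label(self, item: tuple[str, int]) -> int:
           return item[1]


   split = {
       "train": [(f"tr-{i:02d}", i % 2) for i in range(8)],
       "validation": [(f"va-{i:02d}", i % 2) for i in range(4)],
       "test": [(f"te-{i:02d}", i % 2) for i in range(4)],
   }

   datamodule = CachingDataModule(
       database_split=split,
       raw_data_loader=RandomLoader(),
       cache_samples=True,  # pre-load everything (sequentially, as parallel=-1)
       batch_size=4,
   )
   # enable the class/dataset-equitable random sampler described above
   datamodule.balance_sampler_by_class = True

   datamodule.setup("fit")
   data, metadata = next(iter(datamodule.train_dataloader()))
   assert data.shape == (4, 1, 16, 16)     # batched sample tensors
   assert metadata["label"].shape == (4,)  # collated integer labels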