From a67626d86cb4238f118e6da7341d87fc28c540a9 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Wed, 28 Jun 2023 22:41:05 +0200 Subject: [PATCH] [datamodule] Slightly streamlines the datamodule approach; adds documentation; adds type annotations; adds TODOs --- src/ptbench/data/base_datamodule.py | 196 ---------- src/ptbench/data/datamodule.py | 503 ++++++++++++++++++++++++ src/ptbench/data/dataset.py | 526 ++++++++------------------ src/ptbench/data/shenzhen/__init__.py | 19 - src/ptbench/data/shenzhen/default.py | 43 ++- src/ptbench/data/shenzhen/loader.py | 68 ++++ src/ptbench/data/shenzhen/utils.py | 114 ------ src/ptbench/data/transforms.py | 2 + 8 files changed, 765 insertions(+), 706 deletions(-) delete mode 100644 src/ptbench/data/base_datamodule.py create mode 100644 src/ptbench/data/datamodule.py create mode 100644 src/ptbench/data/shenzhen/loader.py delete mode 100644 src/ptbench/data/shenzhen/utils.py diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py deleted file mode 100644 index 8377c663..00000000 --- a/src/ptbench/data/base_datamodule.py +++ /dev/null @@ -1,196 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import multiprocessing -import sys - -import lightning as pl -import torch - -from clapper.logging import setup -from torch.utils.data import DataLoader, WeightedRandomSampler -from torchvision import transforms - -from .dataset import get_samples_weights - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class BaseDataModule(pl.LightningDataModule): - def __init__( - self, - batch_size=1, - batch_chunk_count=1, - drop_incomplete_batch=False, - parallel=-1, - ): - super().__init__() - - self.batch_size = batch_size - self.batch_chunk_count = batch_chunk_count - - self.train_dataset = None - self.validation_dataset = None - self.extra_validation_datasets = None - - self._raw_data_transforms = [] - self._model_transforms = [] - self._augmentation_transforms = [] - - self.drop_incomplete_batch = drop_incomplete_batch - self.pin_memory = ( - torch.cuda.is_available() - ) # should only be true if GPU available and using it - - self.parallel = parallel - - def setup(self, stage: str): - # Implemented by user - # Must define self.train_dataset, self.validation_dataset, self.extra_validation_datasets and self.predict_dataset - raise NotImplementedError - - def train_dataloader(self): - train_samples_weights = get_samples_weights(self.train_dataset) - - multiproc_kwargs = self._setup_multiproc(self.parallel) - train_sampler = WeightedRandomSampler( - train_samples_weights, len(train_samples_weights), replacement=True - ) - - return DataLoader( - self.train_dataset, - batch_size=self._compute_chunk_size( - self.batch_size, self.batch_chunk_count - ), - drop_last=self.drop_incomplete_batch, - pin_memory=self.pin_memory, - sampler=train_sampler, - **multiproc_kwargs, - ) - - def val_dataloader(self): - loaders_dict = {} - - multiproc_kwargs = self._setup_multiproc(self.parallel) - - val_loader = DataLoader( - dataset=self.validation_dataset, - batch_size=self._compute_chunk_size( - self.batch_size, self.batch_chunk_count - ), - shuffle=False, - drop_last=False, - pin_memory=self.pin_memory, - **multiproc_kwargs, - ) - - loaders_dict["validation_loader"] = val_loader - - if self.extra_validation_datasets is not None: - for set_idx, extra_set in enumerate(self.extra_validation_datasets): - 
extra_val_loader = DataLoader( - dataset=extra_set, - batch_size=self._compute_chunk_size( - self.batch_size, self.batch_chunk_count - ), - shuffle=False, - drop_last=False, - pin_memory=self.pin_memory, - **multiproc_kwargs, - ) - - loaders_dict[ - f"extra_validation_loader{set_idx}" - ] = extra_val_loader - - return loaders_dict - - def predict_dataloader(self): - loaders_dict = {} - - loaders_dict["train_loader"] = self.train_dataloader() - for k, v in self.val_dataloader().items(): - loaders_dict[k] = v - - return loaders_dict - - def update_module_properties(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - - def _compute_chunk_size(self, batch_size, chunk_count): - batch_chunk_size = batch_size - if batch_size % chunk_count != 0: - # batch_size must be divisible by batch_chunk_count. - raise RuntimeError( - f"--batch-size ({batch_size}) must be divisible by " - f"--batch-chunk-size ({chunk_count})." - ) - else: - batch_chunk_size = batch_size // chunk_count - - return batch_chunk_size - - def _setup_multiproc(self, parallel): - multiproc_kwargs = dict() - if parallel < 0: - multiproc_kwargs["num_workers"] = 0 - else: - multiproc_kwargs["num_workers"] = ( - parallel or multiprocessing.cpu_count() - ) - - if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin": - multiproc_kwargs[ - "multiprocessing_context" - ] = multiprocessing.get_context("spawn") - - return multiproc_kwargs - - def _build_transforms(self, is_train): - all_transforms = self.raw_data_transforms + self.model_transforms - - if is_train: - all_transforms = all_transforms + self.augmentation_transforms - - # all_transforms.append(transforms.ToTensor()) - - return transforms.Compose(all_transforms) - - @property - def raw_data_transforms(self): - return self._raw_data_transforms - - @raw_data_transforms.setter - def raw_data_transforms(self, transforms): - self._raw_data_transforms = transforms - - @property - def model_transforms(self): - return self._model_transforms - - @model_transforms.setter - def model_transforms(self, transforms): - self._model_transforms = transforms - - @property - def augmentation_transforms(self): - return self._augmentation_transforms - - @augmentation_transforms.setter - def augmentation_transforms(self, transforms): - self._augmentation_transforms = transforms - - -def get_dataset_from_module(module, stage, **module_args): - """Instantiates a DataModule and retrieves the corresponding dataset. - - Useful when combining multiple datasets. - """ - module_instance = module(**module_args) - module_instance.prepare_data() - module_instance.setup(stage=stage) - dataset = module_instance.dataset - - return dataset diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py new file mode 100644 index 00000000..11e76697 --- /dev/null +++ b/src/ptbench/data/datamodule.py @@ -0,0 +1,503 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import collections +import multiprocessing +import sys +import typing + +import lightning +import torch +import torch.utils.data + +from clapper.logging import setup + +# TODO: No logging on this module... +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +def _setup_dataloader_multiproc_parameters( + parallel: int, +) -> dict[str, typing.Any]: + """Returns a dictionary containing pytorch arguments to be used in data + loaders. 
+ + It sets the parameter ``num_workers`` to match the expected pytorch + representation. For macOS machines, it also sets the + ``multiprocessing_context`` to use ``spawn`` instead of the default. + + The mapping between the command-line interface ``parallel`` setting works + like this: + + .. list-table:: Relationship between ``parallel`` and DataLoader parameterisation + :widths: 15 15 70 + :header-rows: 1 + + * - CLI ``parallel`` + - :py:class:`torch.utils.data.DataLoader` ``kwargs`` + - Comments + * - ``<0`` + - 0 + - Disables multiprocessing entirely, executes everything within the + same processing context + * - ``0`` + - :py:func:`multiprocessing.cpu_count` + - Runs mini-batch data loading on as many external processes as CPUs + available in the current machine + * - ``>=1`` + - ``parallel`` + - Runs mini-batch data loading on as many external processes as set on + ``parallel`` + """ + + retval: dict[str, typing.Any] = dict() + if parallel < 0: + retval["num_workers"] = 0 + else: + retval["num_workers"] = parallel or multiprocessing.cpu_count() + + if retval["num_workers"] > 0 and sys.platform == "darwin": + retval["multiprocessing_context"] = multiprocessing.get_context("spawn") + + return retval + + +class _DelayedLoadingDataset(torch.utils.data.Dataset): + """A list that loads its samples on demand. + + This list mimics a pytorch Dataset, except raw data loading is done + on-the-fly, as the samples are requested through the bracket operator. + + + Parameters + ---------- + + split + An iterable containing the raw dataset samples loaded from the database + splits. + + raw_data_loader + A callable that will take the representation of the samples declared + inside database splits, load the actual raw data (if one exists), and + transform it into a :py:class:`torch.Tensor`containing floats in the + interval ``[0, 1]``, and a dictionary with further metadata attributes. + + transforms + A set of transforms that should be applied on-the-fly for this dataset. + """ + + def __init__( + self, + split: typing.Sequence[typing.Any], + raw_data_loader: typing.Callable[ + [typing.Any], tuple[torch.Tensor, typing.Mapping] + ], + transforms: typing.Optional[typing.Sequence[torch.nn.Module]] = None, + ): + self.split = split + self.raw_data_loader = raw_data_loader + self.transform = torch.nn.Sequential(*transforms) + + def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]: + tensor, metadata = self.raw_data_loader(self.split[key]) + return self.transform(tensor), metadata + + def __len__(self): + return len(self.split) + + +class _CachedDataset(torch.utils.data.Dataset): + """Basically, a list of preloaded samples. + + This dataset will load all samples from the split during construction + instead of delaying that to the indexing. + + + Parameters + ---------- + + split + An iterable containing the raw dataset samples loaded from the database + splits. + + raw_data_loader + A callable that will take the representation of the samples declared + inside database splits, load the actual raw data (if one exists), and + transform it into a :py:class:`torch.Tensor`containing floats in the + interval ``[0, 1]``, and a dictionary with further metadata attributes. + + transforms + A set of transforms that should be applied on-the-fly for this dataset. 
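+
+    A minimal usage sketch (``my_split`` and ``my_loader`` are hypothetical
+    stand-ins for a real database split and raw-data loader):
+
+    .. code:: python
+
+        dataset = _CachedDataset(my_split["train"], my_loader, transforms=[])
+        tensor, metadata = dataset[0]  # served from memory, no disk access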
+ """ + + def __init__( + self, + split: typing.Sequence[typing.Any], + raw_data_loader: typing.Callable[ + [typing.Any], tuple[torch.Tensor, typing.Mapping] + ], + transforms: typing.Optional[typing.Sequence[torch.nn.Module]] = None, + ): + self.data = [raw_data_loader(k) for k in split] + self.transform = torch.nn.Sequential(*transforms) + + def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]: + tensor, metadata = self.data[key] + return self.transform(tensor), metadata + + def __len__(self): + return len(self.data) + + +def get_sample_weights( + dataset: _DelayedLoadingDataset | _CachedDataset, +) -> torch.Tensor: + """Computes the (inverse) probabilities of samples based on their class. + + This function takes as input a torch Dataset, and computes the weights to + balance each class in the dataset, and the datasets themselves if one + passes a :py:class:`torch.utils.data.ConcatDataset`. + + This function assumes labels are stored on the second entry of sample and + are integers. If that is not the case, we just balance the overall dataset + w.r.t. each other (e.g. with a :py:class:`torch.utils.data.ConcatDataset`). + For single dataset (not a concatenated one) without labels this function is + the same as a no-op. + + + Parameters + ---------- + + dataset + An instance of torch Dataset. + :py:class:`torch.utils.data.ConcatDataset` are supported + + + Returns + ------- + + sample_weights + The weights for all the samples in the dataset given as input + """ + retval = [] + + def _calculate_dataset_weights( + d: _DelayedLoadingDataset | _CachedDataset, + ) -> torch.Tensor: + """Calculates the weights for one dataset.""" + if "label" in d[0][1] and isinstance(d[0][1]["label"], int): + # there are labels + targets = [k[1]["label"] for k in d] + counts = collections.Counter(targets) + weights = {k: 1.0 / v for k, v in counts.items()} + weight_per_sample = [weights[k[1]] for k in d] + return torch.tensor(weight_per_sample) + else: + # no labels, weight dataset samples only by count + # n.b.: only works if __len__(d) is implemented, we turn off typechecking + weight = 1.0 / len(d) + return torch.tensor(len(d) * [weight]) + + if isinstance(dataset, torch.utils.data.ConcatDataset): + for ds in dataset.datasets: + retval.append(_calculate_dataset_weights(ds)) # type: ignore[arg-type] + + # Concatenate sample weights from all the datasets + return torch.cat(retval) + + return _calculate_dataset_weights(dataset) + + +class CachingDataModule(lightning.LightningDataModule): + """A conveninent data module with CSV or JSON protocol loading, mini- + batching, parallelisation and caching, all in one. + + Instances of this class load data-split (a.k.a. protocol) definitions for a + database, and can load the data from the disk. An optional caching + mechanism stores the data at associated CPU memory, which can improve data + serving while training and evaluating models. + + This datamodule defines basic operations to handle data loading and + mini-batch handling within this package's framework. It can return + :py:class:`torch.utils.data.DataLoader` objects for training, validation, + prediction and testing conditions. Parallelisation is handled by a simple + input flag. + + Users must implement the basic :py:meth:`setup` function, which is + parameterised by a single string enumeration containing: ``fit``, + ``validate``, ``test``, or ``predict``. 
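+
+    A minimal usage sketch (``my_split`` and ``my_loader`` are hypothetical
+    stand-ins for a database split and raw-data loader as described below):
+
+    .. code:: python
+
+        datamodule = CachingDataModule(
+            database_split=my_split,
+            raw_data_loader=my_loader,
+            batch_size=4,
+        )
+        datamodule.setup("fit")
+        train_loader = datamodule.train_dataloader()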
+ + + Parameters + ---------- + + database_split + A dictionary that contains string keys representing subset names, and + values that are iterables over sample representations (potentially on + disk). These objects are passed to the ``raw_data_loader`` for loading + the sample data (and metadata) in memory. The objects represented may + be of any format (e.g. list, dictionary, etc), for as long as the + ``raw_data_loader`` can properly handle it. To check the split and the + loader function works correctly, you may use + :py:func:`..dataset.check_database_split_loading`. As is, this class + expects at least one entry called ``train`` to exist in the input + dictionary. Optional entries are ``validation``, and ``test``. Entries + named ``monitor-...`` will be considered extra subsets that do not + influence any early stop criteria during training, and are just + monitored beyond the ``validation`` dataset. + + raw_data_loader + A callable that will take the representation of the samples declared + inside database splits, load the actual raw data (if one exists), and + transform it into a :py:class:`torch.Tensor`containing floats in the + interval ``[0, 1]``, and a dictionary with further metadata attributes. + Samples can be cached **after** raw data loading. + + cache_samples + If set, then issue raw data loading during ``prepare_data()``, and + serves samples from CPU memory. Otherwise, loads samples from disk on + demand. Running from CPU memory will offer increased speeds in exchange + for CPU memory. Sufficient CPU memory must be available before you set + this attribute to ``True``. It is typicall useful for relatively small + datasets. + + train_sampler + If set, then reset the default random sampler from torch with a + (potentially) biased one of your choice. Notice this will not play an + effect on the batching strategy, only on the way the samples are picked + from the original dataset. A typical replacement is: + + .. code:: python + + sample_weights = get_sample_weights(dataset) + train_sampler = torch.utils.data.WeightedRandomSampler( + sample_weights, len(sample_weights), replacement=True + ) + + The variable ``sample_weights`` represents the probabilities (may not + sum to one) of picking each particular sample in the dataset for a + batch. We typically set ``replacement=True`` to avoid issues with + missing data from one of the classes in mini-batches. This is + particularly important in highly unbalanced datasets. The function + :py:func:`get_sample_weights` may help you in this aspect. + + data_augmentations + A list of torchvision transforms (torch modules) that will be applied + on training set samples to create data augmentations during the + training of a model. Augmentation transform pipelines are applied + *after* the raw data is loaded, and before ``model_transforms``. + Augmentation transforms assume they receive a torch tensor representing + an image as input (see :py:class:`torchvision.transforms.ToTensor` for + details), in the range ``[0, 1]``. + + model_transforms + A list of torchvision transforms (torch modules) that will be applied + after data augmentation transforms, and just before data is fed into + the model for all data loaders produced by this data module. This part + of the pipeline receives data as output by the raw-data-loader, or from + data augmentations, if any is specified. + + batch_size + Number of samples in every **training** batch (this parameter affects + memory requirements for the network). 
If the number of samples in the + batch is larger than the total number of samples available for + training, this value is truncated. If this number is smaller, then + batches of the specified size are created and fed to the network until + there are no more new samples to feed (epoch is finished). If the + total number of training samples is not a multiple of the batch-size, + the last batch will be smaller than the first, unless + ``drop_incomplete_batch`` is set to ``true``, in which case this batch + is not used. + + batch_chunk_count + Number of chunks in every batch (this parameter affects memory + requirements for the network). The number of samples loaded for every + iteration will be ``batch_size/batch_chunk_count``. ``batch_size`` + needs to be divisible by ``batch_chunk_count``, otherwise an error will + be raised. This parameter is used to reduce number of samples loaded in + each iteration, in order to reduce the memory usage in exchange for + processing time (more iterations). This is specially interesting whe + one is running with GPUs with limited RAM. The default of 1 forces the + whole batch to be processed at once. Otherwise the batch is broken into + batch-chunk-count pieces, and gradients are accumulated to complete + each batch. + + drop_incomplete_batch + If set, then may drop the last batch in an epoch, in case it is + incomplete. If you set this option, you should also consider + increasing the total number of epochs of training, as the total number + of training steps may be reduced. + + parallel + Use multiprocessing for data loading: if set to -1 (default), disables + multiprocessing data loading. Set to 0 to enable as many data loading + instances as processing cores as available in the system. Set to >= 1 + to enable that many multiprocessing instances for data loading. + """ + + def __init__( + self, + database_split: dict[str, typing.Sequence[typing.Any]], + raw_data_loader: typing.Callable[ + [typing.Any], tuple[torch.Tensor, typing.Mapping] + ], + cache_samples: bool = False, + train_sampler: typing.Optional[torch.utils.data.Sampler] = None, + data_augmentations: list[torch.nn.Module] = [], + model_transforms: list[torch.nn.Module] = [], + batch_size: int = 1, + batch_chunk_count: int = 1, + drop_incomplete_batch: bool = False, + parallel: int = -1, + ): + # validation + if batch_size % batch_chunk_count != 0: + raise RuntimeError( + f"batch_size ({batch_size}) must be divisible by " + f"batch_chunk_size ({batch_chunk_count})." + ) + + super().__init__() + + self.database_split = database_split + self.raw_data_loader = raw_data_loader + self.cache_samples = cache_samples + self.train_sampler = train_sampler + self.data_augmentations = data_augmentations + self.model_transforms = model_transforms + + self._batch_size = batch_size + self._batch_chunk_count = batch_chunk_count + self._chunk_size = self._batch_size // self._batch_chunk_count + + self.drop_incomplete_batch = drop_incomplete_batch + self._parallel = parallel # immutable, otherwise would need to call + # the next function again + self._dataloader_multiproc = _setup_dataloader_multiproc_parameters( + parallel + ) + + self.pin_memory = ( + torch.cuda.is_available() + ) # should only be true if GPU available and using it + + # datasets that have been setup() for the current stage + self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} + + def setup(self, stage: str) -> None: + """Sets up datasets for different tasks on the pipeline. 
+ + This method should setup (load, pre-process, etc) all subsets required + for a particular ``stage`` (fit, validate, test, predict), and keep + them ready to be used on one of the `_dataloader()` functions that are + pertinent for such stage. + + If you have set ``cache_samples``, samples are loaded at this stage and + cached in memory. + + + Parameters + ---------- + + stage + Name of the stage to which the setup is applicable. Can be one of + ``fit``, ``validate``, ``test`` or ``predict``. Each stage + typically uses the following data loaders: + + * ``fit``: uses both train and validation datasets + * ``validate``: uses only the validation dataset + * ``test``: uses only the test dataset + * ``predict``: uses only the test dataset + """ + + def _setup(name, transforms): + if self.cache_samples: + self._datasets[name] = _CachedDataset( + self.database_split[name], self.raw_data_loader, transforms + ) + else: + self._datasets[name] = _DelayedLoadingDataset( + self.database_split[name], self.raw_data_loader, transforms + ) + + if stage == "fit": + _setup("train", self.data_augmentations + self.model_transforms) + _setup("validation", self.model_transforms) + for k in self.database_split: + if k.startswith("monitor-"): + _setup(k, self.model_transforms) + + elif stage == "validate": + _setup("validation", self.model_transforms) + for k in self.database_split: + if k.startswith("monitor-"): + _setup(k, self.model_transforms) + + elif stage == "test": + _setup("test", self.model_transforms) + + elif stage == "predict": + _setup("test", self.model_transforms) + + def train_dataloader(self): + """Returns the train data loader.""" + + return torch.utils.data.DataLoader( + self._datasets["train"], + batch_size=self._chunk_size, + drop_last=self.drop_incomplete_batch, + pin_memory=self.pin_memory, + sampler=self.train_sampler, + **self._dataloader_multiproc, + ) + + def val_dataloader(self): + """Returns the validation data loader(s)""" + + extra_valid = [ + k for k in self.database_split.keys() if k.startswith("monitor-") + ] + + # TODO: do we really need the train sampler here? 
+ validation_loader_opts = { + "batch_size": self._chunk_size, + "shuffle": False, + "drop_last": self.drop_incomplete_batch, + "pin_memory": self.pin_memory, + } + validation_loader_opts.update(self._dataloader_multiproc) + + # TODO: not sure this is the right way to handle multiple validation + # loaders, please check and fix + if not extra_valid: + return torch.utils.data.DataLoader( + self._datasets["validation"], + **validation_loader_opts, + ) + + else: + return [ + torch.utils.data.DataLoader( + self._datasets[k], + **validation_loader_opts, + ) + for k in ["validation"] + extra_valid + ] + + def test_dataloader(self): + """Returns the test data loader(s)""" + + return torch.utils.data.DataLoader( + self._datasets["test"], + batch_size=self._chunk_size, + shuffle=False, + drop_last=self.drop_incomplete_batch, + pin_memory=self.pin_memory, + **self._dataloader_multiproc, + ) + + def predict_dataloader(self): + """Returns the prediction data loader(s)""" + + return self.test_dataloader() diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py index 15dc32a9..6b373db4 100644 --- a/src/ptbench/data/dataset.py +++ b/src/ptbench/data/dataset.py @@ -2,453 +2,261 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import collections.abc import csv +import importlib.abc import json import logging -import os import pathlib +import typing import torch - -from tqdm import tqdm +import torch.utils.data logger = logging.getLogger(__name__) -class JSONProtocol: - """Generic multi-protocol/subset filelist dataset that yields samples. +class JSONDatabaseSplit( + dict, + typing.Mapping[str, typing.Sequence[typing.Any]], +): + """Defines a loader that understands a database split (train, test, etc) in + JSON format. - To create a new dataset, you need to provide one or more JSON formatted - filelists (one per protocol) with the following contents: + To create a new database split, you need to provide a JSON formatted + dictionary in a file, with contents similar to the following: .. code-block:: json { "subset1": [ [ - "value1", - "value2", - "value3" + "sample1-data1", + "sample1-data2", + "sample1-data3", ], [ - "value4", - "value5", - "value6" + "sample2-data1", + "sample2-data2", + "sample2-data3", ] ], "subset2": [ + [ + "sample42-data1", + "sample42-data2", + "sample42-data3", + ], ] } - Your dataset many contain any number of subsets, but all sample entries - must contain the same number of fields. + Your database split many contain any number of subsets (dictionary keys). + For simplicity, we recommend all sample entries are formatted similarly so + that raw-data-loading is simplified. Use the function + :py:func:`check_database_split_loading` to test raw data loading and fine + tune the dataset split, or its loading. + + Objects of this class behave like a dictionary in which keys are subset + names in the split, and values represent samples data and meta-data. Parameters ---------- - protocols : list, dict - Paths to one or more JSON formatted files containing the various - protocols to be recognized by this dataset, or a dictionary, mapping - protocol names to paths (or opened file objects) of CSV files. - Internally, we save a dictionary where keys default to the basename of - paths (list input). - - fieldnames : list, tuple - An iterable over the field names (strings) to assign to each entry in - the JSON file. It should have as many items as fields in each entry of - the JSON file. 
- - loader : object - A function that receives as input, a context dictionary (with at least - a "protocol" and "subset" keys indicating which protocol and subset are - being served), and a dictionary with ``{fieldname: value}`` entries, - and returns an object with at least 2 attributes: - - * ``key``: which must be a unique string for every sample across - subsets in a protocol, and - * ``data``: which contains the data associated witht this sample + path + Absolute path to a JSON formatted file containing the database split to be + recognized by this object. """ - def __init__(self, protocols, fieldnames): - if isinstance(protocols, dict): - self._protocols = protocols - else: - self._protocols = { - os.path.basename( - str(k).replace("".join(pathlib.Path(k).suffixes), "") - ): k - for k in protocols - } - self.fieldnames = fieldnames - - def check(self, limit=0): - """For each protocol, check if all data can be correctly accessed. - - This function assumes each sample has a ``data`` and a ``key`` - attribute. The ``key`` attribute should be a string, or representable - as such. - - - Parameters - ---------- - - limit : int - Maximum number of samples to check (in each protocol/subset - combination) in this dataset. If set to zero, then check - everything. - + def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable): + if isinstance(path, str): + path = pathlib.Path(path) + self.path = path + self.subsets = self._load_split_from_disk() - Returns - ------- + def _load_split_from_disk(self) -> dict[str, typing.Sequence[typing.Any]]: + """Loads all subsets in a split from its file system representation. - errors : int - Number of errors found - """ - logger.info("Checking dataset...") - errors = 0 - for proto in self._protocols: - logger.info(f"Checking protocol '{proto}'...") - for name, samples in self.subsets(proto).items(): - logger.info(f"Checking subset '{name}'...") - if limit: - logger.info(f"Checking at most first '{limit}' samples...") - samples = samples[:limit] - for pos, sample in enumerate(samples): - try: - sample.data # may trigger data loading - logger.info(f"{sample.key}: OK") - except Exception as e: - logger.error( - f"Found error loading entry {pos} in subset {name} " - f"of protocol {proto} from file " - f"'{self._protocols[proto]}': {e}" - ) - errors += 1 - return errors - - def subsets(self, protocol): - """Returns all subsets in a protocol. - - This method will load JSON information for a given protocol and return - all subsets of the given protocol after converting each entry through - the loader function. - - Parameters - ---------- - - protocol : str - Name of the protocol data to load + This method will load JSON information for the current split and return + all subsets of the given split after converting each entry through the + loader function. Returns ------- subsets : dict - A dictionary mapping subset names to lists of objects (respecting - the ``key``, ``data`` interface). 
+ A dictionary mapping subset names to lists of JSON objects """ - fileobj = self._protocols[protocol] - if isinstance(fileobj, (str, bytes, pathlib.Path)): - if str(fileobj).endswith(".bz2"): - import bz2 - with bz2.open(self._protocols[protocol]) as f: - data = json.load(f) - else: - with open(self._protocols[protocol]) as f: - data = json.load(f) + if str(self.path).endswith(".bz2"): + logger.debug(f"Loading database split from {str(self.path)}...") + with __import__("bz2").open(self.path) as f: + return json.load(f) else: - data = json.load(fileobj) - fileobj.seek(0) + with self.path.open() as f: + return json.load(f) - retval = {} - for subset, samples in data.items(): - logger.info(f"Loading subset {subset} samples.") + def __getitem__(self, key: str) -> typing.Sequence[typing.Any]: + """Accesses subset ``key`` from this split.""" + return self.subsets[key] - retval[subset] = [ - dict(zip(self.fieldnames, k)) - for n, k in enumerate(tqdm(samples)) - ] + def __iter__(self): + """Iterates over the subsets.""" + return iter(self.subsets) + + def __len__(self) -> int: + """How many subsets we currently have.""" + return len(self.subsets) - return retval +class CSVDatabaseSplit(collections.abc.Mapping): + """Defines a loader that understands a database split (train, test, etc) in + CSV format. -class CSVDataset: - """Generic multi-subset filelist dataset that yields samples. + To create a new database split, you need to provide one or more CSV + formatted files, each representing a subset of this split, containing the + sample data (one per row). Example: - To create a new dataset, you only need to provide a CSV formatted filelist - using any separator (e.g. comma, space, semi-colon) with the following - information: + Inside the directory ``my-split/``, one can file files ``train.csv``, + ``validation.csv``, and ``test.csv``. Each file has a structure similar to + the following: .. code-block:: text - value1,value2,value3 - value4,value5,value6 + sample1-value1,sample1-value2,sample1-value3 + sample2-value1,sample2-value2,sample2-value3 ... - Notice that all rows must have the same number of entries. + Each file in the provided directory defines the subset name on the split. + So, the file ``train.csv`` will contain the data from the ``train`` subset, + and so on. + + Objects of this class behave like a dictionary in which keys are subset + names in the split, and values represent samples data and meta-data. + Parameters ---------- - subsets : list, dict - Paths to one or more CSV formatted files containing the various subsets - to be recognized by this dataset, or a dictionary, mapping subset names - to paths (or opened file objects) of CSV files. Internally, we save a - dictionary where keys default to the basename of paths (list input). - - fieldnames : list, tuple - An iterable over the field names (strings) to assign to each column in - the CSV file. It should have as many items as fields in each row of - the CSV file(s). - - loader : object - A function that receives as input, a context dictionary (with, at - least, a "subset" key indicating which subset is being served), and a - dictionary with ``{key: path}`` entries, and returns a dictionary with - the loaded data. + directory + Absolute path to a directory containing the database split layed down + as a set of CSV files, one per subset. 
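+
+    Example (the path below is hypothetical):
+
+    .. code:: python
+
+        split = CSVDatabaseSplit("/path/to/my-split")
+        for row in split["train"]:
+            print(row)  # each row is a list of CSV field values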
""" - def __init__(self, subsets, fieldnames, loader): - if isinstance(subsets, dict): - self._subsets = subsets - else: - self._subsets = { - os.path.basename( - str(k).replace("".join(pathlib.Path(k).suffixes), "") - ): k - for k in subsets - } - self.fieldnames = fieldnames - self._loader = loader - - def check(self, limit=0): - """For each subset, check if all data can be correctly accessed. - - This function assumes each sample has a ``data`` and a ``key`` - attribute. The ``key`` attribute should be a string, or representable - as such. - - - Parameters - ---------- - - limit : int - Maximum number of samples to check (in each protocol/subset - combination) in this dataset. If set to zero, then check - everything. + def __init__( + self, directory: pathlib.Path | str | importlib.abc.Traversable + ): + if isinstance(directory, str): + directory = pathlib.Path(directory) + assert ( + directory.is_dir() + ), f"`{str(directory)}` is not a valid directory" + self.directory = directory + self.subsets = self._load_split_from_disk() + def _load_split_from_disk( + self, + ) -> dict[str, list[typing.Any]]: + """Loads all subsets in a split from its file system representation. - Returns - ------- + This method will load CSV information for the current split and return all + subsets of the given split after converting each entry through the + loader function. - errors : int - Number of errors found - """ - logger.info("Checking dataset...") - errors = 0 - for name in self._subsets.keys(): - logger.info(f"Checking subset '{name}'...") - samples = self.samples(name) - if limit: - logger.info(f"Checking at most first '{limit}' samples...") - samples = samples[:limit] - for pos, sample in enumerate(samples): - try: - sample.data # may trigger data loading - logger.info(f"{sample.key}: OK") - except Exception as e: - logger.error( - f"Found error loading entry {pos} in subset {name} " - f"from file '{self._subsets[name]}': {e}" - ) - errors += 1 - return errors - - def subsets(self): - """Returns all available subsets at once. Returns ------- subsets : dict - A dictionary mapping subset names to lists of objects (respecting - the ``key``, ``data`` interface). - """ - return {k: self.samples(k) for k in self._subsets.keys()} - - def samples(self, subset): - """Returns all samples in a subset. - - This method will load CSV information for a given subset and return - all samples of the given subset after passing each entry through the - loading function. - - - Parameters - ---------- - - subset : str - Name of the subset data to load - - - Returns - ------- - - subset : list - A lists of objects (respecting the ``key``, ``data`` interface). 
+ A dictionary mapping subset names to lists of JSON objects """ - fileobj = self._subsets[subset] - if isinstance(fileobj, (str, bytes, pathlib.Path)): - with open(self._subsets[subset], newline="") as f: - cf = csv.reader(f) - samples = [k for k in cf] - else: - cf = csv.reader(fileobj) - samples = [k for k in cf] - fileobj.seek(0) - return [ - self._loader( - dict(subset=subset, order=n), dict(zip(self.fieldnames, k)) - ) - for n, k in enumerate(samples) - ] - - -class CachedDataset(torch.utils.data.Dataset): - def __init__( - self, json_protocol, protocol, subset, raw_data_loader, transforms - ): - self.json_protocol = json_protocol - self.subset = subset - self.raw_data_loader = raw_data_loader - self.transforms = transforms - - self._samples = json_protocol.subsets(protocol)[self.subset] - # Dict entry with relative path to files, used during prediction - for s in self._samples: - s["name"] = s["data"] - - logger.info(f"Caching {self.subset} samples") - for sample in tqdm(self._samples): - sample["data"] = self.raw_data_loader(sample["data"]) - - def __getitem__(self, idx): - sample = self._samples[idx].copy() - sample["data"] = self.transforms(sample["data"]) - return sample - - def __len__(self): - return len(self._samples) - - -class RuntimeDataset(torch.utils.data.Dataset): - def __init__( - self, json_protocol, protocol, subset, raw_data_loader, transforms - ): - self.json_protocol = json_protocol - self.subset = subset - self.raw_data_loader = raw_data_loader - self.transforms = transforms + retval = {} + for subset in self.directory.iterdir(): + if str(subset).endswith(".csv.bz2"): + logger.debug(f"Loading database split from {subset}...") + with __import__("bz2").open(subset) as f: + reader = csv.reader(f) + retval[subset.name[: -len(".csv.bz2")]] = [ + k for k in reader + ] + elif str(subset).endswith(".csv"): + with subset.open() as f: + reader = csv.reader(f) + retval[subset.name[: -len(".csv")]] = [k for k in reader] + else: + logger.debug( + f"Ignoring file {subset} in CSVDatabaseSplit readout" + ) + return retval - self._samples = json_protocol.subsets(protocol)[self.subset] - # Dict entry with relative path to files - for s in self._samples: - s["name"] = s["data"] + def __getitem__(self, key: str) -> typing.Sequence[typing.Any]: + """Accesses subset ``key`` from this split.""" + return self.subsets[key] - def __getitem__(self, idx): - sample = self._samples[idx].copy() - sample["data"] = self.transforms(self.raw_data_loader(sample["data"])) - return sample + def __iter__(self): + """Iterates over the subsets.""" + return iter(self.subsets) - def __len__(self): - return len(self._samples) + def __len__(self) -> int: + """How many subsets we currently have.""" + return len(self.subsets) -def get_samples_weights(dataset): - """Compute the weights of all the samples of the dataset to balance it - using the sampler of the dataloader. +def check_database_split_loading( + database_split: typing.Mapping[str, typing.Sequence[typing.Any]], + loader: typing.Callable[[typing.Any], torch.Tensor], + limit: int = 0, +) -> int: + """For each subset in the split, check if all data can be correctly loaded + using the provided loader function. - This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` - and computes the weights to balance each class in the dataset and the - datasets themselves if we have a ConcatDataset. + This function will return the number of errors loading samples, and will + log more detailed information to the logging stream. 
Parameters ---------- - dataset : torch.utils.data.dataset.Dataset - An instance of torch.utils.data.dataset.Dataset - ConcatDataset are supported + database_split + A mapping that, contains the database split. Each key represents the + name of a subset in the split. Each value is a (potentially complex) + object that represents a single sample. + + loader + A callable that transforms sample entries in the database split + into :py:class:`torch.Tensor` objects that can be used for training + or inference. + + limit + Maximum number of samples to check (in each split/subset + combination) in this dataset. If set to zero, then check + everything. Returns ------- - samples_weights : :py:class:`torch.Tensor` - the weights for all the samples in the dataset given as input + errors + Number of errors found """ - samples_weights = [] - - if isinstance(dataset, torch.utils.data.ConcatDataset): - for ds in dataset.datasets: - # Weighting only for binary labels - if isinstance(ds._samples[0]["label"], int): - # Groundtruth - targets = [] - for s in ds._samples: - targets.append(s["label"]) - targets = torch.tensor(targets) - - # Count number of samples per class - class_sample_count = torch.tensor( - [ - (targets == t).sum() - for t in torch.unique(targets, sorted=True) - ] + logger.info( + "Checking if can load all samples in all subsets of this split..." + ) + errors = 0 + for subset in database_split.keys(): + samples = subset if not limit else subset[:limit] + for pos, sample in enumerate(samples): + try: + data = loader(sample) + assert isinstance(data, torch.Tensor) + except Exception as e: + logger.info( + f"Found error loading entry {pos} in subset `{subset}`: {e}" ) - - weight = 1.0 / class_sample_count.float() - - samples_weights.append( - torch.tensor([weight[t] for t in targets]) - ) - - else: - # We only weight to sample equally from each dataset - samples_weights.append(torch.full((len(ds),), 1.0 / len(ds))) - - # Concatenate sample weights from all the datasets - samples_weights = torch.cat(samples_weights) - - else: - # Weighting only for binary labels - if isinstance(dataset._samples[0]["label"], int): - # Groundtruth - targets = [] - for s in dataset._samples: - targets.append(s["label"]) - targets = torch.tensor(targets) - - # Count number of samples per class - class_sample_count = torch.tensor( - [ - (targets == t).sum() - for t in torch.unique(targets, sorted=True) - ] - ) - - weight = 1.0 / class_sample_count.float() - - samples_weights = torch.tensor([weight[t] for t in targets]) - - else: - # Equal weights for non-binary labels - samples_weights = torch.ones(len(dataset._samples)) - - return samples_weights + errors += 1 + return errors def get_positive_weights(dataset): diff --git a/src/ptbench/data/shenzhen/__init__.py b/src/ptbench/data/shenzhen/__init__.py index 54eb6326..1645962e 100644 --- a/src/ptbench/data/shenzhen/__init__.py +++ b/src/ptbench/data/shenzhen/__init__.py @@ -13,22 +13,11 @@ the daily routine using Philips DR Digital Diagnose systems. 
* Reference: [MONTGOMERY-SHENZHEN-2014]_ * Original resolution (height x width or width x height): 3000 x 3000 or less * Split reference: none -* Protocol ``default``: - * Training samples: 64% of TB and healthy CXR (including labels) * Validation samples: 16% of TB and healthy CXR (including labels) * Test samples: 20% of TB and healthy CXR (including labels) """ import importlib.resources -import os - -from clapper.logging import setup - -from ...utils.rc import load_rc -from ..loader import load_pil_baw - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - _protocols = [ importlib.resources.files(__name__).joinpath("default.json.bz2"), @@ -43,11 +32,3 @@ _protocols = [ importlib.resources.files(__name__).joinpath("fold_8.json.bz2"), importlib.resources.files(__name__).joinpath("fold_9.json.bz2"), ] - -_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir)) - - -def _raw_data_loader(img_path): - raw_data = load_pil_baw(os.path.join(_datadir, img_path)) - - return raw_data diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py index b5fb23f5..c9068112 100644 --- a/src/ptbench/data/shenzhen/default.py +++ b/src/ptbench/data/shenzhen/default.py @@ -2,27 +2,34 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen dataset for TB detection (default protocol) +"""Shenzhen datamodule for computer-aided diagnosis (default protocol) -* Split reference: first 64% of TB and healthy CXR for "train" 16% for -* "validation", 20% for "test" -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.shenzhen` for dataset details +See :py:mod:`ptbench.data.shenzhen` for dataset details. + +This configuration: +* raw data (default): :py:obj:`ptbench.data.shenzhen._tranforms` +* augmentations: elastic deformation (probability = 80%) +* output image resolution: 512x512 pixels """ -from clapper.logging import setup +import importlib.resources +from ..datamodule import CachingDataModule +from ..dataset import JSONDatabaseSplit from ..transforms import ElasticDeformation -from .utils import ShenzhenDataModule - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - -protocol_name = "default" - -augmentation_transforms = [ElasticDeformation(p=0.8)] - -datamodule = ShenzhenDataModule( - protocol="default", - model_transforms=[], - augmentation_transforms=augmentation_transforms, +from .loader import raw_data_loader + +datamodule = CachingDataModule( + database_split=JSONDatabaseSplit( + importlib.resources.files(__name__).joinpath("default.json.bz2") + ), + raw_data_loader=raw_data_loader, + cache_samples=False, + # train_sampler: typing.Optional[torch.utils.data.Sampler] = None, + data_augmentations=[ElasticDeformation(p=0.8)], + # model_transforms = [], + # batch_size = 1, + # batch_chunk_count = 1, + # drop_incomplete_batch = False, + # parallel = -1, ) diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/loader.py new file mode 100644 index 00000000..6578fc8f --- /dev/null +++ b/src/ptbench/data/shenzhen/loader.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Shenzhen dataset for computer-aided diagnosis. + +The standard digital image database for Tuberculosis is created by the National +Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s +Hospital, Guangdong Medical College, Shenzhen, China. 
The Chest X-rays are from
+out-patient clinics, and were captured as part of the daily routine using
+Philips DR Digital Diagnose systems.
+
+* Reference: [MONTGOMERY-SHENZHEN-2014]_
+* Original resolution (height x width or width x height): 3000 x 3000 or less
+* Split reference: none
+* Protocol ``default``:
+
+  * Training samples: 64% of TB and healthy CXR (including labels)
+  * Validation samples: 16% of TB and healthy CXR (including labels)
+  * Test samples: 20% of TB and healthy CXR (including labels)
+"""
+
+import os
+import typing
+
+import torch.nn
+import torchvision.transforms
+
+from ...utils.rc import load_rc
+from ..loader import load_pil_baw
+from ..transforms import RemoveBlackBorders
+
+_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
+"""This variable contains the base directory where the database raw data is
+stored."""
+
+_transform = torchvision.transforms.Compose(
+    [
+        RemoveBlackBorders(),
+        torchvision.transforms.Resize(512),
+        torchvision.transforms.CenterCrop(512),
+        torchvision.transforms.ToTensor(),
+    ]
+)
+"""Transforms that are always applied to the loaded raw images."""
+
+
+def raw_data_loader(
+    sample: tuple[str, int]
+) -> tuple[torch.Tensor, typing.Mapping]:
+    """Loads a single image sample from the disk.
+
+    Parameters
+    ----------
+
+    sample
+        A tuple containing the path suffix (within the dataset root folder) of
+        the image to be loaded, and its integer label.
+
+
+    Returns
+    -------
+
+    sample
+        The image as a ``[0, 1]`` float tensor, and a dict with the sample label
+    """
+    tensor = _transform(load_pil_baw(os.path.join(_datadir, sample[0])))
+    return tensor, dict(label=sample[1])  # type: ignore[arg-type]
diff --git a/src/ptbench/data/shenzhen/utils.py b/src/ptbench/data/shenzhen/utils.py
deleted file mode 100644
index 1521b674..00000000
--- a/src/ptbench/data/shenzhen/utils.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Shenzhen dataset for computer-aided diagnosis.
-
-The standard digital image database for Tuberculosis is created by the
-National Library of Medicine, Maryland, USA in collaboration with Shenzhen
-No.3 People’s Hospital, Guangdong Medical College, Shenzhen, China.
-The Chest X-rays are from out-patient clinics, and were captured as part of
-the daily routine using Philips DR Digital Diagnose systems.
- -* Reference: [MONTGOMERY-SHENZHEN-2014]_ -* Original resolution (height x width or width x height): 3000 x 3000 or less -* Split reference: none -* Protocol ``default``: - - * Training samples: 64% of TB and healthy CXR (including labels) - * Validation samples: 16% of TB and healthy CXR (including labels) - * Test samples: 20% of TB and healthy CXR (including labels) -""" -from clapper.logging import setup -from torchvision import transforms - -from ..base_datamodule import BaseDataModule -from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset -from ..shenzhen import _protocols, _raw_data_loader -from ..transforms import RemoveBlackBorders - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class ShenzhenDataModule(BaseDataModule): - def __init__( - self, - protocol="default", - model_transforms=[], - augmentation_transforms=[], - batch_size=1, - batch_chunk_count=1, - drop_incomplete_batch=False, - cache_samples=False, - parallel=-1, - ): - super().__init__( - batch_size=batch_size, - drop_incomplete_batch=drop_incomplete_batch, - batch_chunk_count=batch_chunk_count, - parallel=parallel, - ) - - self._cache_samples = cache_samples - self._has_setup_fit = False - self._has_setup_predict = False - self._protocol = protocol - - self.raw_data_transforms = [ - RemoveBlackBorders(), - transforms.Resize(512), - transforms.CenterCrop(512), - transforms.ToTensor(), - ] - - self.model_transforms = model_transforms - - self.augmentation_transforms = augmentation_transforms - - def setup(self, stage: str): - json_protocol = JSONProtocol( - protocols=_protocols, - fieldnames=("data", "label"), - ) - - if self._cache_samples: - dataset = CachedDataset - else: - dataset = RuntimeDataset - - if not self._has_setup_fit and stage == "fit": - self.train_dataset = dataset( - json_protocol, - self._protocol, - "train", - _raw_data_loader, - self._build_transforms(is_train=True), - ) - - self.validation_dataset = dataset( - json_protocol, - self._protocol, - "validation", - _raw_data_loader, - self._build_transforms(is_train=False), - ) - - self._has_setup_fit = True - - if not self._has_setup_predict and stage == "predict": - self.train_dataset = dataset( - json_protocol, - self._protocol, - "train", - _raw_data_loader, - self._build_transforms(is_train=False), - ) - self.validation_dataset = dataset( - json_protocol, - self._protocol, - "validation", - _raw_data_loader, - self._build_transforms(is_train=False), - ) - - self._has_setup_predict = True diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py index 6d3c17d2..c85c3e9d 100644 --- a/src/ptbench/data/transforms.py +++ b/src/ptbench/data/transforms.py @@ -58,6 +58,8 @@ class RemoveBlackBorders: class ElasticDeformation: """Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_. + TODO: needs to be converted into a torch.nn.Module to become scriptable! + Source: https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0 """ -- GitLab