diff --git a/doc/references.rst b/doc/references.rst index 5a349fa9dd960aa393f8f9d5dfd6f696676317e6..056707406d9e3dbf5d73fe5bc03cdd7f49701cbb 100644 --- a/doc/references.rst +++ b/doc/references.rst @@ -1,5 +1,6 @@ - -.. coding=utf-8 +.. SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +.. +.. SPDX-License-Identifier: GPL-3.0-or-later ============ References diff --git a/pyproject.toml b/pyproject.toml index bb93be327219f1e58eadc46317074d91eb3235e8..7ce5443548baaf829fa3c877ef8a35f805554ed4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,33 +15,33 @@ dynamic = ["readme"] license = { text = "GNU General Public License v3 (GPLv3)" } authors = [{ name = "Geoffrey Raposo", email = "geoffrey@raposo.ch" }] maintainers = [ - { name = "Andre Anjos", email = "andre.anjos@idiap.ch" }, - { name = "Daniel Carron", email = "daniel.carron@idiap.ch" }, + { name = "Andre Anjos", email = "andre.anjos@idiap.ch" }, + { name = "Daniel Carron", email = "daniel.carron@idiap.ch" }, ] classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", - "Natural Language :: English", - "Programming Language :: Python :: 3", - "Topic :: Software Development :: Libraries :: Python Modules", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", + "Natural Language :: English", + "Programming Language :: Python :: 3", + "Topic :: Software Development :: Libraries :: Python Modules", ] dependencies = [ - "clapper", - "click", - "numpy", - "pandas", - "scipy", - "scikit-learn", - "tqdm", - "psutil", - "tabulate", - "matplotlib", - "pillow", - "torch>=1.8", - "torchvision>=0.10", - "lightning>=2.0.3", - "tensorboard", + "clapper", + "click", + "numpy", + "pandas", + "scipy", + "scikit-learn", + "tqdm", + "psutil", + "tabulate", + "matplotlib", + "pillow", + "torch>=1.8", + "torchvision>=0.10", + "lightning>=2.0.3", + "tensorboard", ] [project.urls] @@ -53,13 +53,13 @@ changelog = "https://gitlab.idiap.ch/biosignal/software/ptbench/-/releases" [project.optional-dependencies] qa = ["pre-commit"] doc = [ - "sphinx", - "furo", - "sphinx-autodoc-typehints", - "auto-intersphinx", - "sphinx-copybutton", - "sphinx-inline-tabs", - "sphinx-click", + "sphinx", + "furo", + "sphinx-autodoc-typehints", + "auto-intersphinx", + "sphinx-copybutton", + "sphinx-inline-tabs", + "sphinx-click", ] test = ["pytest", "pytest-cov", "coverage"] diff --git a/src/ptbench/configs/models/alexnet.py b/src/ptbench/configs/models/alexnet.py index cf8bfd35aa10ad3493be9483ba0347fc8ebb7da1..2361b886d500fee740e456f2505a36da4fdaf4e3 100644 --- a/src/ptbench/configs/models/alexnet.py +++ b/src/ptbench/configs/models/alexnet.py @@ -6,19 +6,30 @@ from torch import empty from torch.nn import BCEWithLogitsLoss +from torch.optim import SGD from ...models.alexnet import Alexnet -# config +# optimizer +optimizer = SGD optimizer_configs = {"lr": 0.01, "momentum": 0.1} -# optimizer -optimizer = "SGD" # criterion criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) +from ...data.transforms import ElasticDeformation + +augmentation_transforms = [ + ElasticDeformation(p=0.8), +] + # model model = Alexnet( - criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False + criterion, + criterion_valid, + optimizer, + optimizer_configs, + pretrained=False, + 
augmentation_transforms=augmentation_transforms, ) diff --git a/src/ptbench/configs/models/alexnet_pretrained.py b/src/ptbench/configs/models/alexnet_pretrained.py index 1d196be6f79ea5c70987c1d1a66eaf32e8e7ca4c..0dc7e5d67d007cf5e7e358e7fa75243a47047c4b 100644 --- a/src/ptbench/configs/models/alexnet_pretrained.py +++ b/src/ptbench/configs/models/alexnet_pretrained.py @@ -6,19 +6,30 @@ from torch import empty from torch.nn import BCEWithLogitsLoss +from torch.optim import SGD from ...models.alexnet import Alexnet -# config -optimizer_configs = {"lr": 0.001, "momentum": 0.1} - # optimizer -optimizer = "SGD" +optimizer = SGD +optimizer_configs = {"lr": 0.01, "momentum": 0.1} + # criterion criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) +from ...data.transforms import ElasticDeformation + +augmentation_transforms = [ + ElasticDeformation(p=0.8), +] + # model model = Alexnet( - criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True + criterion, + criterion_valid, + optimizer, + optimizer_configs, + pretrained=True, + augmentation_transforms=augmentation_transforms, ) diff --git a/src/ptbench/configs/models/densenet.py b/src/ptbench/configs/models/densenet.py index 6759490854fd05ac8d8d3a9eec5b1494e0cfb0f2..5d612b2a18146ba306b419a808936e9c3c7042f7 100644 --- a/src/ptbench/configs/models/densenet.py +++ b/src/ptbench/configs/models/densenet.py @@ -6,20 +6,30 @@ from torch import empty from torch.nn import BCEWithLogitsLoss +from torch.optim import Adam from ...models.densenet import Densenet -# config -optimizer_configs = {"lr": 0.0001} - # optimizer -optimizer = "Adam" +optimizer = Adam +optimizer_configs = {"lr": 0.0001} # criterion criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) +from ...data.transforms import ElasticDeformation + +augmentation_transforms = [ + ElasticDeformation(p=0.8), +] + # model model = Densenet( - criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False + criterion, + criterion_valid, + optimizer, + optimizer_configs, + pretrained=False, + augmentation_transforms=augmentation_transforms, ) diff --git a/src/ptbench/configs/models/densenet_pretrained.py b/src/ptbench/configs/models/densenet_pretrained.py index b018a52203061b847cdae9f09b5edfa713930302..f8908fdb1e87a62df41dca0ecb75ff1fc79b1012 100644 --- a/src/ptbench/configs/models/densenet_pretrained.py +++ b/src/ptbench/configs/models/densenet_pretrained.py @@ -6,20 +6,30 @@ from torch import empty from torch.nn import BCEWithLogitsLoss +from torch.optim import Adam from ...models.densenet import Densenet -# config -optimizer_configs = {"lr": 0.01} - # optimizer -optimizer = "Adam" +optimizer = Adam +optimizer_configs = {"lr": 0.0001} # criterion criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) +from ...data.transforms import ElasticDeformation + +augmentation_transforms = [ + ElasticDeformation(p=0.8), +] + # model model = Densenet( - criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True + criterion, + criterion_valid, + optimizer, + optimizer_configs, + pretrained=True, + augmentation_transforms=augmentation_transforms, ) diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py index 3ee0b92164b5531b65049b94e71b01b07e2ad27e..d1e1b0a3ae8d9e3e32a7ec19a49e21f01bb694d9 100644 --- a/src/ptbench/configs/models/pasa.py +++ b/src/ptbench/configs/models/pasa.py @@ -11,20 
+11,16 @@ Screening and Visualization". Reference: [PASA-2019]_ """ -from torch import empty from torch.nn import BCEWithLogitsLoss - -from ...models.pasa import PASA - -# config -optimizer_configs = {"lr": 8e-5} - -# optimizer -optimizer = "Adam" - -# criterion -criterion = BCEWithLogitsLoss(pos_weight=empty(1)) -criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) - -# model -model = PASA(criterion, criterion_valid, optimizer, optimizer_configs) +from torch.optim import Adam + +from ...data.transforms import ElasticDeformation +from ...models.pasa import Pasa + +model = Pasa( + train_loss=BCEWithLogitsLoss(), + validation_loss=BCEWithLogitsLoss(), + optimizer_type=Adam, + optimizer_arguments=dict(lr=8e-5), + augmentation_transforms=[ElasticDeformation(p=0.8)], +) diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py deleted file mode 100644 index 8377c66328241f549dbb2f10946f9cc973ef7a6f..0000000000000000000000000000000000000000 --- a/src/ptbench/data/base_datamodule.py +++ /dev/null @@ -1,196 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import multiprocessing -import sys - -import lightning as pl -import torch - -from clapper.logging import setup -from torch.utils.data import DataLoader, WeightedRandomSampler -from torchvision import transforms - -from .dataset import get_samples_weights - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class BaseDataModule(pl.LightningDataModule): - def __init__( - self, - batch_size=1, - batch_chunk_count=1, - drop_incomplete_batch=False, - parallel=-1, - ): - super().__init__() - - self.batch_size = batch_size - self.batch_chunk_count = batch_chunk_count - - self.train_dataset = None - self.validation_dataset = None - self.extra_validation_datasets = None - - self._raw_data_transforms = [] - self._model_transforms = [] - self._augmentation_transforms = [] - - self.drop_incomplete_batch = drop_incomplete_batch - self.pin_memory = ( - torch.cuda.is_available() - ) # should only be true if GPU available and using it - - self.parallel = parallel - - def setup(self, stage: str): - # Implemented by user - # Must define self.train_dataset, self.validation_dataset, self.extra_validation_datasets and self.predict_dataset - raise NotImplementedError - - def train_dataloader(self): - train_samples_weights = get_samples_weights(self.train_dataset) - - multiproc_kwargs = self._setup_multiproc(self.parallel) - train_sampler = WeightedRandomSampler( - train_samples_weights, len(train_samples_weights), replacement=True - ) - - return DataLoader( - self.train_dataset, - batch_size=self._compute_chunk_size( - self.batch_size, self.batch_chunk_count - ), - drop_last=self.drop_incomplete_batch, - pin_memory=self.pin_memory, - sampler=train_sampler, - **multiproc_kwargs, - ) - - def val_dataloader(self): - loaders_dict = {} - - multiproc_kwargs = self._setup_multiproc(self.parallel) - - val_loader = DataLoader( - dataset=self.validation_dataset, - batch_size=self._compute_chunk_size( - self.batch_size, self.batch_chunk_count - ), - shuffle=False, - drop_last=False, - pin_memory=self.pin_memory, - **multiproc_kwargs, - ) - - loaders_dict["validation_loader"] = val_loader - - if self.extra_validation_datasets is not None: - for set_idx, extra_set in enumerate(self.extra_validation_datasets): - extra_val_loader = DataLoader( - dataset=extra_set, - batch_size=self._compute_chunk_size( - self.batch_size, 
self.batch_chunk_count - ), - shuffle=False, - drop_last=False, - pin_memory=self.pin_memory, - **multiproc_kwargs, - ) - - loaders_dict[ - f"extra_validation_loader{set_idx}" - ] = extra_val_loader - - return loaders_dict - - def predict_dataloader(self): - loaders_dict = {} - - loaders_dict["train_loader"] = self.train_dataloader() - for k, v in self.val_dataloader().items(): - loaders_dict[k] = v - - return loaders_dict - - def update_module_properties(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - - def _compute_chunk_size(self, batch_size, chunk_count): - batch_chunk_size = batch_size - if batch_size % chunk_count != 0: - # batch_size must be divisible by batch_chunk_count. - raise RuntimeError( - f"--batch-size ({batch_size}) must be divisible by " - f"--batch-chunk-size ({chunk_count})." - ) - else: - batch_chunk_size = batch_size // chunk_count - - return batch_chunk_size - - def _setup_multiproc(self, parallel): - multiproc_kwargs = dict() - if parallel < 0: - multiproc_kwargs["num_workers"] = 0 - else: - multiproc_kwargs["num_workers"] = ( - parallel or multiprocessing.cpu_count() - ) - - if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin": - multiproc_kwargs[ - "multiprocessing_context" - ] = multiprocessing.get_context("spawn") - - return multiproc_kwargs - - def _build_transforms(self, is_train): - all_transforms = self.raw_data_transforms + self.model_transforms - - if is_train: - all_transforms = all_transforms + self.augmentation_transforms - - # all_transforms.append(transforms.ToTensor()) - - return transforms.Compose(all_transforms) - - @property - def raw_data_transforms(self): - return self._raw_data_transforms - - @raw_data_transforms.setter - def raw_data_transforms(self, transforms): - self._raw_data_transforms = transforms - - @property - def model_transforms(self): - return self._model_transforms - - @model_transforms.setter - def model_transforms(self, transforms): - self._model_transforms = transforms - - @property - def augmentation_transforms(self): - return self._augmentation_transforms - - @augmentation_transforms.setter - def augmentation_transforms(self, transforms): - self._augmentation_transforms = transforms - - -def get_dataset_from_module(module, stage, **module_args): - """Instantiates a DataModule and retrieves the corresponding dataset. - - Useful when combining multiple datasets. - """ - module_instance = module(**module_args) - module_instance.prepare_data() - module_instance.setup(stage=stage) - dataset = module_instance.dataset - - return dataset diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..abcf11d79286ecdb16cc374fb920f68418ec603a --- /dev/null +++ b/src/ptbench/data/datamodule.py @@ -0,0 +1,722 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import collections +import logging +import multiprocessing +import sys +import typing + +import lightning +import torch +import torch.backends +import torch.utils.data +import torchvision.transforms +import tqdm + +from .typing import ( + DatabaseSplit, + DataLoader, + Dataset, + RawDataLoader, + Sample, + Transform, +) + +logger = logging.getLogger(__name__) + + +def _setup_dataloader_multiproc_parameters( + parallel: int, +) -> dict[str, typing.Any]: + """Returns a dictionary containing pytorch arguments to be used in data + loaders. 
+
+    It sets the parameter ``num_workers`` to match the expected pytorch
+    representation.  For macOS machines, it also sets the
+    ``multiprocessing_context`` to use ``spawn`` instead of the default.
+
+    The mapping between the command-line interface ``parallel`` setting and
+    the dataloader parameterisation works like this:
+
+    .. list-table:: Relationship between ``parallel`` and DataLoader parameterisation
+       :widths: 15 15 70
+       :header-rows: 1
+
+       * - CLI ``parallel``
+         - :py:class:`torch.utils.data.DataLoader` ``kwargs``
+         - Comments
+       * - ``<0``
+         - 0
+         - Disables multiprocessing entirely, executes everything within the
+           same processing context
+       * - ``0``
+         - :py:func:`multiprocessing.cpu_count`
+         - Runs mini-batch data loading on as many external processes as CPUs
+           available in the current machine
+       * - ``>=1``
+         - ``parallel``
+         - Runs mini-batch data loading on as many external processes as set on
+           ``parallel``
+    """
+
+    retval: dict[str, typing.Any] = dict()
+    if parallel < 0:
+        retval["num_workers"] = 0
+    else:
+        retval["num_workers"] = parallel or multiprocessing.cpu_count()
+
+    if retval["num_workers"] > 0 and sys.platform == "darwin":
+        retval["multiprocessing_context"] = multiprocessing.get_context("spawn")
+
+    return retval
+
+
+class _DelayedLoadingDataset(Dataset):
+    """A list that loads its samples on demand.
+
+    This list mimics a pytorch Dataset, except raw data loading is done
+    on-the-fly, as the samples are requested through the bracket operator.
+
+
+    Parameters
+    ----------
+
+    split
+        An iterable containing the raw dataset samples loaded from the database
+        splits.
+
+    loader
+        An object instance that can load samples and labels from storage.
+
+    transforms
+        A set of transforms that should be applied on-the-fly for this dataset,
+        to fit the output of the raw-data-loader to the model of interest.
+    """
+
+    def __init__(
+        self,
+        split: typing.Sequence[typing.Any],
+        loader: RawDataLoader,
+        transforms: typing.Sequence[Transform] = [],
+    ):
+        self.split = split
+        self.loader = loader
+        self.transform = torchvision.transforms.Compose(transforms)
+
+    def labels(self) -> list[int]:
+        """Returns the integer labels for all samples in the dataset."""
+        return [self.loader.label(k) for k in self.split]
+
+    def __getitem__(self, key: int) -> Sample:
+        tensor, metadata = self.loader.sample(self.split[key])
+        return self.transform(tensor), metadata
+
+    def __len__(self):
+        return len(self.split)
+
+    def __iter__(self):
+        for x in range(len(self)):
+            yield self[x]
+
+
+class _CachedDataset(Dataset):
+    """A list of preloaded samples.
+
+    This dataset loads all samples from the split during construction instead
+    of delaying that to indexing time.  Beyond raw data loading, ``transforms``
+    given upon construction contribute to the cached samples.
+
+
+    Parameters
+    ----------
+
+    split
+        An iterable containing the raw dataset samples loaded from the database
+        splits.
+
+    loader
+        An object instance that can load samples and labels from storage.
+
+    parallel
+        Use multiprocessing for data loading: if set to -1 (default), disables
+        multiprocessing data loading.  Set to 0 to enable as many data loading
+        instances as processing cores available in the system.  Set to >= 1
+        to enable that many multiprocessing instances for data loading.
+
+    transforms
+        A set of transforms that should be applied to the cached samples for
+        this dataset, to fit the output of the raw-data-loader to the model of
+        interest.
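+
+    Example (an illustrative sketch; ``split`` and ``loader`` are hypothetical
+    stand-ins for one subset of a database split and a compatible
+    :py:class:`RawDataLoader` instance):
+
+    .. code-block:: python
+
+       dataset = _CachedDataset(split, loader, parallel=4)
+       tensor, metadata = dataset[0]  # served from CPU memory, not disk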
+    """
+
+    def __init__(
+        self,
+        split: typing.Sequence[typing.Any],
+        loader: RawDataLoader,
+        parallel: int = -1,
+        transforms: typing.Sequence[Transform] = [],
+    ):
+        self.transform = torchvision.transforms.Compose(transforms)
+
+        if parallel < 0:
+            self.data = [
+                loader.sample(k) for k in tqdm.tqdm(split, unit="sample")
+            ]
+        else:
+            instances = parallel or multiprocessing.cpu_count()
+            logger.info(f"Caching dataset using {instances} processes...")
+            with multiprocessing.Pool(instances) as p:
+                self.data = list(
+                    tqdm.tqdm(p.imap(loader.sample, split), total=len(split))
+                )
+
+    def labels(self) -> list[int]:
+        """Returns the integer labels for all samples in the dataset."""
+        return [k[1]["label"] for k in self.data]
+
+    def __getitem__(self, key: int) -> Sample:
+        tensor, metadata = self.data[key]
+        return self.transform(tensor), metadata
+
+    def __len__(self):
+        return len(self.data)
+
+    def __iter__(self):
+        for x in range(len(self)):
+            yield self[x]
+
+
+def _make_balanced_random_sampler(
+    dataset: Dataset,
+    target: str = "label",
+) -> torch.utils.data.WeightedRandomSampler:
+    """Generates a pytorch sampler that samples according to class
+    probabilities.
+
+    This function takes as input a torch Dataset, and computes the weights to
+    balance each class in the dataset, and the datasets themselves if one
+    passes a :py:class:`torch.utils.data.ConcatDataset`.
+
+    In this implementation, we balance **both** class and dataset-origin
+    probabilities, which is what you expect from a truly *equitable* random
+    sampler.
+
+    Take this example for illustration:
+
+    * Dataset 1: N = 10 samples, 9 samples with target=0, 1 sample with target=1
+    * Dataset 2: N = 6 samples, 3 samples with target=0, 3 samples with target=1
+
+    So:
+
+    +---------+--------+---------+--------+-------------------+
+    | Dataset | Target | Samples | Weight | Normalised weight |
+    +=========+========+=========+========+===================+
+    | 1       | 0      | 9       | 1/9    | 1/36              |
+    +---------+--------+---------+--------+-------------------+
+    | 1       | 1      | 1       | 1/1    | 1/4               |
+    +---------+--------+---------+--------+-------------------+
+    | 2       | 0      | 3       | 1/3    | 1/12              |
+    +---------+--------+---------+--------+-------------------+
+    | 2       | 1      | 3       | 1/3    | 1/12              |
+    +---------+--------+---------+--------+-------------------+
+
+    Legend:
+
+    * Weight: the weights computed by this method
+    * Normalised weight: the weight per sample used by the random sampler,
+      obtained by dividing each weight by the sum of all weights in the
+      concatenated dataset, so that the normalised weights, summed over all
+      samples, total exactly 1.
+
+    The properties of this algorithm are as follows:
+
+    1. The probability of picking a sample from any target is the same (0.5 in
+       this case).  To verify this, notice that the probability of picking a
+       sample with ``target=0`` is :math:`9 x 1/36 + 3 x 1/12 = 0.5`.
+    2. The probability of picking a sample with ``target=0`` from Dataset 2 is
+       3 times higher than from Dataset 1.  As there are 3 times fewer samples
+       in Dataset 2 with ``target=0``, this makes choosing samples from
+       Dataset 1 proportionally less likely.
+    3. The probability of picking a sample with ``target=1`` from Dataset 2 is
+       3 times lower than from Dataset 1.  As there are 3 times fewer samples
+       in Dataset 1 with ``target=1``, this makes choosing samples from
+       Dataset 2 proportionally less likely.
+
+    This function assumes targets are stored in a metadata dictionary entry
+    whose name is given by ``target`` inside each ``Sample``, and that their
+    values are integers.
+
+    We then instantiate a pytorch sampler using the inverse probabilities (the
+    more samples of a class, the less likely it becomes to be sampled).
+
+
+    Parameters
+    ----------
+
+    dataset
+        An instance of torch Dataset.
+        :py:class:`torch.utils.data.ConcatDataset` instances are supported.
+
+    target
+        The name of a metadata key pointing to an integer property that allows
+        balancing the dataset.
+
+
+    Returns
+    -------
+
+    sampler
+        A sampler, to be used in a dataloader equipped with the same dataset
+        used to calculate the relative sample weights.
+
+
+    Raises
+    ------
+
+    RuntimeError
+        If requested to balance a dataset (single, not-concatenated) without an
+        existing target.
+    """
+
+    def _calculate_weights(targets: list[int]) -> list[float]:
+        counts = collections.Counter(targets)
+        weights = {k: 1.0 / v for k, v in counts.items()}
+        return [weights[k] for k in targets]
+
+    if isinstance(dataset, torch.utils.data.ConcatDataset):
+        # There are two possible cases: targets/no-targets
+        metadata_example = dataset.datasets[0][0][1]
+        if target in metadata_example and isinstance(
+            metadata_example[target], int
+        ):
+            # there are integer targets, let's balance with those
+            logger.info(
+                f"Balancing sample selection probabilities **and** "
+                f"concatenated-datasets using metadata targets `{target}`"
+            )
+            targets = [
+                k
+                for ds in dataset.datasets
+                for k in typing.cast(Dataset, ds).labels()
+            ]
+            weights = _calculate_weights(targets)
+        else:
+            logger.warning(
+                f"Balancing samples **and** concatenated-datasets "
+                f"WITHOUT metadata targets (`{target}` not available)"
+            )
+            # no targets available: balance across dataset origins only, by
+            # giving every sample a weight inversely proportional to the size
+            # of the dataset it comes from
+            weights = [
+                k
+                for ds in dataset.datasets
+                for k in len(typing.cast(typing.Sized, ds))
+                * [1.0 / len(typing.cast(typing.Sized, ds))]
+            ]
+
+    else:
+        metadata_example = dataset[0][1]
+        if target in metadata_example and isinstance(
+            metadata_example[target], int
+        ):
+            logger.info(
+                f"Balancing samples from dataset using metadata "
+                f"targets `{target}`"
+            )
+            weights = _calculate_weights(dataset.labels())
+        else:
+            raise RuntimeError(
+                f"Cannot balance samples without metadata targets `{target}`"
+            )
+
+    return torch.utils.data.WeightedRandomSampler(
+        weights, len(weights), replacement=True
+    )
+
+
+class CachingDataModule(lightning.LightningDataModule):
+    """A convenient data module with CSV or JSON protocol loading,
+    mini-batching, parallelisation and caching, all in one.
+
+    Instances of this class load data-split (a.k.a. protocol) definitions for a
+    database, and can load the data from the disk.  An optional caching
+    mechanism stores the data in associated CPU memory, which can improve data
+    serving while training and evaluating models.
+
+    This datamodule defines basic operations to handle data loading and
+    mini-batch handling within this package's framework.  It can return
+    :py:class:`torch.utils.data.DataLoader` objects for training, validation,
+    prediction and testing conditions.  Parallelisation is handled by a simple
+    input flag.
+
+    This class implements the basic :py:meth:`setup` function, which is
+    parameterised by a single string enumeration containing: ``fit``,
+    ``validate``, ``test``, or ``predict``.
+
+
+    Parameters
+    ----------
+
+    database_split
+        A dictionary that contains string keys representing subset names, and
+        values that are iterables over sample representations (potentially on
+        disk).  These objects are passed to the ``raw_data_loader`` for loading
+        the sample data (and metadata) in memory.  The objects represented may
+        be of any format (e.g. list, dictionary, etc), for as long as the
+        ``raw_data_loader`` can properly handle it.  To check that the split
+        and the loader function work correctly, you may use
+        :py:func:`..dataset.check_database_split_loading`.  As is, this class
+        expects at least one entry called ``train`` to exist in the input
+        dictionary.
+        Optional entries are ``validation`` and ``test``.  Entries named
+        ``monitor-...`` will be considered extra subsets that do not
+        influence any early stop criteria during training, and are just
+        monitored beyond the ``validation`` dataset.
+
+    raw_data_loader
+        An object instance that can load samples and labels from storage.
+
+    cache_samples
+        If set, then issue raw data loading during ``setup()``, and serve
+        samples from CPU memory.  Otherwise, load samples from disk on demand.
+        Serving from CPU memory offers increased speed in exchange for CPU
+        memory usage.  Sufficient CPU memory must be available before you set
+        this attribute to ``True``.  It is typically useful for relatively
+        small datasets.
+
+    balance_sampler_by_class
+        If set, then modifies the random sampler used during training and
+        validation to balance sample picking probability, making sampling
+        across classes **and** datasets equitable.
+
+    model_transforms
+        A list of transforms (torch modules) that will be applied after
+        raw data loading, and just before data is fed into the model or
+        eventual data-augmentation transformations, for all data loaders
+        produced by this data module.  This part of the pipeline receives data
+        as output by the raw-data-loader, or by model-related transforms (e.g.
+        resize adaptations), if any are specified.
+
+    batch_size
+        Number of samples in every **training** batch (this parameter affects
+        memory requirements for the network).  If the number of samples in the
+        batch is larger than the total number of samples available for
+        training, this value is truncated.  If this number is smaller, then
+        batches of the specified size are created and fed to the network until
+        there are no more new samples to feed (epoch is finished).  If the
+        total number of training samples is not a multiple of the batch-size,
+        the last batch will be smaller than the first, unless
+        ``drop_incomplete_batch`` is set to ``true``, in which case this batch
+        is not used.
+
+    batch_chunk_count
+        Number of chunks in every batch (this parameter affects memory
+        requirements for the network).  The number of samples loaded for every
+        iteration will be ``batch_size/batch_chunk_count``.  ``batch_size``
+        needs to be divisible by ``batch_chunk_count``, otherwise an error will
+        be raised.  This parameter is used to reduce the number of samples
+        loaded in each iteration, in order to reduce memory usage in exchange
+        for processing time (more iterations).  This is especially interesting
+        when one is running on GPUs with limited RAM.  The default of 1 forces
+        the whole batch to be processed at once.  Otherwise the batch is broken
+        into batch-chunk-count pieces, and gradients are accumulated to
+        complete each batch.
+
+    drop_incomplete_batch
+        If set, then may drop the last batch in an epoch, in case it is
+        incomplete.  If you set this option, you should also consider
+        increasing the total number of epochs of training, as the total number
+        of training steps may be reduced.
+
+    parallel
+        Use multiprocessing for data loading: if set to -1 (default), disables
+        multiprocessing data loading.  Set to 0 to enable as many data loading
+        instances as processing cores available in the system.  Set to >= 1
+        to enable that many multiprocessing instances for data loading.
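+
+    Example (a minimal usage sketch; ``my_split`` and ``MyRawDataLoader`` are
+    hypothetical stand-ins for a concrete database split and raw-data loader;
+    the explicit ``setup()`` call is normally issued by the lightning trainer):
+
+    .. code-block:: python
+
+       datamodule = CachingDataModule(
+           database_split=my_split,
+           raw_data_loader=MyRawDataLoader(),
+           model_transforms=[],
+           batch_size=16,
+           parallel=4,
+       )
+       datamodule.setup(stage="fit")
+       train_loader = datamodule.train_dataloader()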
+    """
+
+    DatasetDictionary = dict[str, Dataset]
+
+    def __init__(
+        self,
+        database_split: DatabaseSplit,
+        raw_data_loader: RawDataLoader,
+        cache_samples: bool = False,
+        balance_sampler_by_class: bool = False,
+        model_transforms: list[Transform] = [],
+        batch_size: int = 1,
+        batch_chunk_count: int = 1,
+        drop_incomplete_batch: bool = False,
+        parallel: int = -1,
+    ):
+        super().__init__()
+
+        self.set_chunk_size(batch_size, batch_chunk_count)
+
+        self.database_split = database_split
+        self.raw_data_loader = raw_data_loader
+        self.cache_samples = cache_samples
+        self.model_transforms = model_transforms
+        self.drop_incomplete_batch = drop_incomplete_batch
+
+        # datasets that have been setup() for the current stage
+        self._datasets: CachingDataModule.DatasetDictionary = {}
+
+        # the property setter below also configures the dataloader
+        # multiprocessing arguments
+        self.parallel = parallel
+
+        # set last: the setter may need to setup() the train dataset in order
+        # to compute sample weights
+        self._train_sampler = None
+        self.balance_sampler_by_class = balance_sampler_by_class
+
+        self.pin_memory = (
+            torch.cuda.is_available() or torch.backends.mps.is_available()
+        )  # should only be true if GPU available and using it
+
+    @property
+    def parallel(self) -> int:
+        """Whether to use multiprocessing for data loading.
+
+        Use multiprocessing for data loading: if set to -1 (default),
+        disables multiprocessing data loading.  Set to 0 to enable as
+        many data loading instances as processing cores available in
+        the system.  Set to >= 1 to enable that many multiprocessing
+        instances for data loading.
+        """
+        return self._parallel
+
+    @parallel.setter
+    def parallel(self, value: int) -> None:
+        self._parallel = value
+        self._dataloader_multiproc = _setup_dataloader_multiproc_parameters(
+            value
+        )
+        # datasets that have been setup() for the current stage
+        self._datasets = {}
+
+    @property
+    def balance_sampler_by_class(self):
+        """Whether to balance samples across labels/datasets.
+
+        If set, then modifies the random sampler used during training
+        and validation to balance sample picking probability, making
+        sampling across classes **and** datasets equitable.
+        """
+        return self._train_sampler is not None
+
+    @balance_sampler_by_class.setter
+    def balance_sampler_by_class(self, value: bool):
+        if value:
+            if "train" not in self._datasets:
+                self._setup_dataset("train")
+            self._train_sampler = _make_balanced_random_sampler(
+                self._datasets["train"]
+            )
+        else:
+            self._train_sampler = None
+
+    def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
+        """Coherently sets the batch-chunk-size after validation.
+
+        Parameters
+        ----------
+
+        batch_size
+            Number of samples in every **training** batch (this parameter
+            affects memory requirements for the network).  If the number of
+            samples in the batch is larger than the total number of samples
+            available for training, this value is truncated.  If this number is
+            smaller, then batches of the specified size are created and fed to
+            the network until there are no more new samples to feed (epoch is
+            finished).  If the total number of training samples is not a
+            multiple of the batch-size, the last batch will be smaller than the
+            first, unless ``drop_incomplete_batch`` is set to ``true``, in
+            which case this batch is not used.
+
+        batch_chunk_count
+            Number of chunks in every batch (this parameter affects memory
+            requirements for the network).  The number of samples loaded for
+            every iteration will be ``batch_size/batch_chunk_count``.
+            ``batch_size`` needs to be divisible by ``batch_chunk_count``,
+            otherwise an error will be raised.
+            This parameter is used to reduce the number of samples loaded in
+            each iteration, in order to reduce memory usage in exchange for
+            processing time (more iterations).  This is especially interesting
+            when one is running on GPUs with limited RAM.  The default of 1
+            forces the whole batch to be processed at once.  Otherwise the
+            batch is broken into batch-chunk-count pieces, and gradients are
+            accumulated to complete each batch.
+        """
+
+        # validation
+        if batch_size % batch_chunk_count != 0:
+            raise RuntimeError(
+                f"batch_size ({batch_size}) must be divisible by "
+                f"batch_chunk_count ({batch_chunk_count})."
+            )
+
+        self._batch_size = batch_size
+        self._batch_chunk_count = batch_chunk_count
+        self._chunk_size = self._batch_size // self._batch_chunk_count
+
+    def _setup_dataset(self, name: str) -> None:
+        """Sets-up a single dataset from the input data split.
+
+        Parameters
+        ----------
+
+        name
+            Name of the dataset to setup.
+        """
+
+        if name in self._datasets:
+            logger.info(
+                f"Dataset `{name}` is already setup. "
+                f"Not re-instantiating it."
+            )
+            return
+        if self.cache_samples:
+            logger.info(
+                f"Loading dataset:`{name}` into memory (caching)."
+                f" Trade-off: CPU RAM: more | Disk: less"
+            )
+            self._datasets[name] = _CachedDataset(
+                self.database_split[name],
+                self.raw_data_loader,
+                self.parallel,
+                self.model_transforms,
+            )
+        else:
+            logger.info(
+                f"Loading dataset:`{name}` without caching."
+                f" Trade-off: CPU RAM: less | Disk: more"
+            )
+            self._datasets[name] = _DelayedLoadingDataset(
+                self.database_split[name],
+                self.raw_data_loader,
+                self.model_transforms,
+            )
+
+    def _val_dataset_keys(self) -> list[str]:
+        """Returns the list of validation dataset names."""
+        return ["validation"] + [
+            k for k in self.database_split.keys() if k.startswith("monitor-")
+        ]
+
+    def setup(self, stage: str) -> None:
+        """Sets up datasets for different tasks on the pipeline.
+
+        This method should setup (load, pre-process, etc) all datasets required
+        for a particular ``stage`` (fit, validate, test, predict), and keep
+        them ready to be used on one of the ``_dataloader()`` functions that
+        are pertinent for such stage.
+
+        If you have set ``cache_samples``, samples are loaded at this stage and
+        cached in memory.
+
+
+        Parameters
+        ----------
+
+        stage
+            Name of the stage to which the setup is applicable.  Can be one of
+            ``fit``, ``validate``, ``test`` or ``predict``.  Each stage
+            typically uses the following data loaders:
+
+            * ``fit``: uses both train and validation datasets
+            * ``validate``: uses only the validation dataset
+            * ``test``: uses only the test dataset
+            * ``predict``: uses only the test dataset
+        """
+
+        if stage == "fit":
+            for k in ["train"] + self._val_dataset_keys():
+                self._setup_dataset(k)
+
+        elif stage == "validate":
+            for k in self._val_dataset_keys():
+                self._setup_dataset(k)
+
+        elif stage == "test":
+            self._setup_dataset("test")
+
+        elif stage == "predict":
+            self._setup_dataset("test")
+
+    def teardown(self, stage: str) -> None:
+        """Tears down datasets for different tasks on the pipeline.
+
+        This method unsets (unloads, removes from memory, etc) all datasets
+        required for a particular ``stage`` (fit, validate, test, predict).
+
+        If you have set ``cache_samples``, this effectively releases all the
+        associated memory.
+
+
+        Parameters
+        ----------
+
+        stage
+            Name of the stage to which the teardown is applicable.  Can be one
+            of ``fit``, ``validate``, ``test`` or ``predict``.
+            Each stage typically uses the following data loaders:
+
+            * ``fit``: uses both train and validation datasets
+            * ``validate``: uses only the validation dataset
+            * ``test``: uses only the test dataset
+            * ``predict``: uses only the test dataset
+        """
+
+        self._datasets = {}
+
+    def train_dataloader(self) -> DataLoader:
+        """Returns the train data loader."""
+
+        return torch.utils.data.DataLoader(
+            self._datasets["train"],
+            shuffle=(self._train_sampler is None),
+            batch_size=self._chunk_size,
+            drop_last=self.drop_incomplete_batch,
+            pin_memory=self.pin_memory,
+            sampler=self._train_sampler,
+            **self._dataloader_multiproc,
+        )
+
+    def unshuffled_train_dataloader(self) -> DataLoader:
+        """Returns the train data loader without shuffling."""
+
+        return torch.utils.data.DataLoader(
+            self._datasets["train"],
+            shuffle=False,
+            batch_size=self._chunk_size,
+            drop_last=False,
+            **self._dataloader_multiproc,
+        )
+
+    def val_dataloader(self) -> dict[str, DataLoader]:
+        """Returns the validation data loader(s)."""
+
+        validation_loader_opts = {
+            "batch_size": self._chunk_size,
+            "shuffle": False,
+            "drop_last": self.drop_incomplete_batch,
+            "pin_memory": self.pin_memory,
+        }
+        validation_loader_opts.update(self._dataloader_multiproc)
+
+        return {
+            k: torch.utils.data.DataLoader(
+                self._datasets[k], **validation_loader_opts
+            )
+            for k in self._val_dataset_keys()
+        }
+
+    def test_dataloader(self) -> dict[str, DataLoader]:
+        """Returns the test data loader(s)."""
+
+        return dict(
+            test=torch.utils.data.DataLoader(
+                self._datasets["test"],
+                batch_size=self._chunk_size,
+                shuffle=False,
+                drop_last=self.drop_incomplete_batch,
+                pin_memory=self.pin_memory,
+                **self._dataloader_multiproc,
+            )
+        )
+
+    def predict_dataloader(self) -> dict[str, DataLoader]:
+        """Returns the prediction data loader(s)."""
+
+        return self.test_dataloader()
diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
deleted file mode 100644
index 15dc32a96f6662b5ba8afbcdf06632b142840f79..0000000000000000000000000000000000000000
--- a/src/ptbench/data/dataset.py
+++ /dev/null
@@ -1,586 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-import csv
-import json
-import logging
-import os
-import pathlib
-
-import torch
-
-from tqdm import tqdm
-
-logger = logging.getLogger(__name__)
-
-
-class JSONProtocol:
-    """Generic multi-protocol/subset filelist dataset that yields samples.
-
-    To create a new dataset, you need to provide one or more JSON formatted
-    filelists (one per protocol) with the following contents:
-
-    .. code-block:: json
-
-        {
-            "subset1": [
-                [
-                    "value1",
-                    "value2",
-                    "value3"
-                ],
-                [
-                    "value4",
-                    "value5",
-                    "value6"
-                ]
-            ],
-            "subset2": [
-            ]
-        }
-
-    Your dataset many contain any number of subsets, but all sample entries
-    must contain the same number of fields.
-
-
-    Parameters
-    ----------
-
-    protocols : list, dict
-        Paths to one or more JSON formatted files containing the various
-        protocols to be recognized by this dataset, or a dictionary, mapping
-        protocol names to paths (or opened file objects) of CSV files.
-        Internally, we save a dictionary where keys default to the basename of
-        paths (list input).
-
-    fieldnames : list, tuple
-        An iterable over the field names (strings) to assign to each entry in
-        the JSON file.  It should have as many items as fields in each entry of
-        the JSON file.
- - loader : object - A function that receives as input, a context dictionary (with at least - a "protocol" and "subset" keys indicating which protocol and subset are - being served), and a dictionary with ``{fieldname: value}`` entries, - and returns an object with at least 2 attributes: - - * ``key``: which must be a unique string for every sample across - subsets in a protocol, and - * ``data``: which contains the data associated witht this sample - """ - - def __init__(self, protocols, fieldnames): - if isinstance(protocols, dict): - self._protocols = protocols - else: - self._protocols = { - os.path.basename( - str(k).replace("".join(pathlib.Path(k).suffixes), "") - ): k - for k in protocols - } - self.fieldnames = fieldnames - - def check(self, limit=0): - """For each protocol, check if all data can be correctly accessed. - - This function assumes each sample has a ``data`` and a ``key`` - attribute. The ``key`` attribute should be a string, or representable - as such. - - - Parameters - ---------- - - limit : int - Maximum number of samples to check (in each protocol/subset - combination) in this dataset. If set to zero, then check - everything. - - - Returns - ------- - - errors : int - Number of errors found - """ - logger.info("Checking dataset...") - errors = 0 - for proto in self._protocols: - logger.info(f"Checking protocol '{proto}'...") - for name, samples in self.subsets(proto).items(): - logger.info(f"Checking subset '{name}'...") - if limit: - logger.info(f"Checking at most first '{limit}' samples...") - samples = samples[:limit] - for pos, sample in enumerate(samples): - try: - sample.data # may trigger data loading - logger.info(f"{sample.key}: OK") - except Exception as e: - logger.error( - f"Found error loading entry {pos} in subset {name} " - f"of protocol {proto} from file " - f"'{self._protocols[proto]}': {e}" - ) - errors += 1 - return errors - - def subsets(self, protocol): - """Returns all subsets in a protocol. - - This method will load JSON information for a given protocol and return - all subsets of the given protocol after converting each entry through - the loader function. - - Parameters - ---------- - - protocol : str - Name of the protocol data to load - - - Returns - ------- - - subsets : dict - A dictionary mapping subset names to lists of objects (respecting - the ``key``, ``data`` interface). - """ - fileobj = self._protocols[protocol] - if isinstance(fileobj, (str, bytes, pathlib.Path)): - if str(fileobj).endswith(".bz2"): - import bz2 - - with bz2.open(self._protocols[protocol]) as f: - data = json.load(f) - else: - with open(self._protocols[protocol]) as f: - data = json.load(f) - else: - data = json.load(fileobj) - fileobj.seek(0) - - retval = {} - for subset, samples in data.items(): - logger.info(f"Loading subset {subset} samples.") - - retval[subset] = [ - dict(zip(self.fieldnames, k)) - for n, k in enumerate(tqdm(samples)) - ] - - return retval - - -class CSVDataset: - """Generic multi-subset filelist dataset that yields samples. - - To create a new dataset, you only need to provide a CSV formatted filelist - using any separator (e.g. comma, space, semi-colon) with the following - information: - - .. code-block:: text - - value1,value2,value3 - value4,value5,value6 - ... - - Notice that all rows must have the same number of entries. 
- - Parameters - ---------- - - subsets : list, dict - Paths to one or more CSV formatted files containing the various subsets - to be recognized by this dataset, or a dictionary, mapping subset names - to paths (or opened file objects) of CSV files. Internally, we save a - dictionary where keys default to the basename of paths (list input). - - fieldnames : list, tuple - An iterable over the field names (strings) to assign to each column in - the CSV file. It should have as many items as fields in each row of - the CSV file(s). - - loader : object - A function that receives as input, a context dictionary (with, at - least, a "subset" key indicating which subset is being served), and a - dictionary with ``{key: path}`` entries, and returns a dictionary with - the loaded data. - """ - - def __init__(self, subsets, fieldnames, loader): - if isinstance(subsets, dict): - self._subsets = subsets - else: - self._subsets = { - os.path.basename( - str(k).replace("".join(pathlib.Path(k).suffixes), "") - ): k - for k in subsets - } - self.fieldnames = fieldnames - self._loader = loader - - def check(self, limit=0): - """For each subset, check if all data can be correctly accessed. - - This function assumes each sample has a ``data`` and a ``key`` - attribute. The ``key`` attribute should be a string, or representable - as such. - - - Parameters - ---------- - - limit : int - Maximum number of samples to check (in each protocol/subset - combination) in this dataset. If set to zero, then check - everything. - - - Returns - ------- - - errors : int - Number of errors found - """ - logger.info("Checking dataset...") - errors = 0 - for name in self._subsets.keys(): - logger.info(f"Checking subset '{name}'...") - samples = self.samples(name) - if limit: - logger.info(f"Checking at most first '{limit}' samples...") - samples = samples[:limit] - for pos, sample in enumerate(samples): - try: - sample.data # may trigger data loading - logger.info(f"{sample.key}: OK") - except Exception as e: - logger.error( - f"Found error loading entry {pos} in subset {name} " - f"from file '{self._subsets[name]}': {e}" - ) - errors += 1 - return errors - - def subsets(self): - """Returns all available subsets at once. - - Returns - ------- - - subsets : dict - A dictionary mapping subset names to lists of objects (respecting - the ``key``, ``data`` interface). - """ - return {k: self.samples(k) for k in self._subsets.keys()} - - def samples(self, subset): - """Returns all samples in a subset. - - This method will load CSV information for a given subset and return - all samples of the given subset after passing each entry through the - loading function. - - - Parameters - ---------- - - subset : str - Name of the subset data to load - - - Returns - ------- - - subset : list - A lists of objects (respecting the ``key``, ``data`` interface). 
- """ - fileobj = self._subsets[subset] - if isinstance(fileobj, (str, bytes, pathlib.Path)): - with open(self._subsets[subset], newline="") as f: - cf = csv.reader(f) - samples = [k for k in cf] - else: - cf = csv.reader(fileobj) - samples = [k for k in cf] - fileobj.seek(0) - - return [ - self._loader( - dict(subset=subset, order=n), dict(zip(self.fieldnames, k)) - ) - for n, k in enumerate(samples) - ] - - -class CachedDataset(torch.utils.data.Dataset): - def __init__( - self, json_protocol, protocol, subset, raw_data_loader, transforms - ): - self.json_protocol = json_protocol - self.subset = subset - self.raw_data_loader = raw_data_loader - self.transforms = transforms - - self._samples = json_protocol.subsets(protocol)[self.subset] - # Dict entry with relative path to files, used during prediction - for s in self._samples: - s["name"] = s["data"] - - logger.info(f"Caching {self.subset} samples") - for sample in tqdm(self._samples): - sample["data"] = self.raw_data_loader(sample["data"]) - - def __getitem__(self, idx): - sample = self._samples[idx].copy() - sample["data"] = self.transforms(sample["data"]) - return sample - - def __len__(self): - return len(self._samples) - - -class RuntimeDataset(torch.utils.data.Dataset): - def __init__( - self, json_protocol, protocol, subset, raw_data_loader, transforms - ): - self.json_protocol = json_protocol - self.subset = subset - self.raw_data_loader = raw_data_loader - self.transforms = transforms - - self._samples = json_protocol.subsets(protocol)[self.subset] - # Dict entry with relative path to files - for s in self._samples: - s["name"] = s["data"] - - def __getitem__(self, idx): - sample = self._samples[idx].copy() - sample["data"] = self.transforms(self.raw_data_loader(sample["data"])) - return sample - - def __len__(self): - return len(self._samples) - - -def get_samples_weights(dataset): - """Compute the weights of all the samples of the dataset to balance it - using the sampler of the dataloader. - - This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` - and computes the weights to balance each class in the dataset and the - datasets themselves if we have a ConcatDataset. 
- - - Parameters - ---------- - - dataset : torch.utils.data.dataset.Dataset - An instance of torch.utils.data.dataset.Dataset - ConcatDataset are supported - - - Returns - ------- - - samples_weights : :py:class:`torch.Tensor` - the weights for all the samples in the dataset given as input - """ - samples_weights = [] - - if isinstance(dataset, torch.utils.data.ConcatDataset): - for ds in dataset.datasets: - # Weighting only for binary labels - if isinstance(ds._samples[0]["label"], int): - # Groundtruth - targets = [] - for s in ds._samples: - targets.append(s["label"]) - targets = torch.tensor(targets) - - # Count number of samples per class - class_sample_count = torch.tensor( - [ - (targets == t).sum() - for t in torch.unique(targets, sorted=True) - ] - ) - - weight = 1.0 / class_sample_count.float() - - samples_weights.append( - torch.tensor([weight[t] for t in targets]) - ) - - else: - # We only weight to sample equally from each dataset - samples_weights.append(torch.full((len(ds),), 1.0 / len(ds))) - - # Concatenate sample weights from all the datasets - samples_weights = torch.cat(samples_weights) - - else: - # Weighting only for binary labels - if isinstance(dataset._samples[0]["label"], int): - # Groundtruth - targets = [] - for s in dataset._samples: - targets.append(s["label"]) - targets = torch.tensor(targets) - - # Count number of samples per class - class_sample_count = torch.tensor( - [ - (targets == t).sum() - for t in torch.unique(targets, sorted=True) - ] - ) - - weight = 1.0 / class_sample_count.float() - - samples_weights = torch.tensor([weight[t] for t in targets]) - - else: - # Equal weights for non-binary labels - samples_weights = torch.ones(len(dataset._samples)) - - return samples_weights - - -def get_positive_weights(dataset): - """Compute the positive weights of each class of the dataset to balance the - BCEWithLogitsLoss criterion. - - This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` - and computes the positive weights of each class to use them to have - a balanced loss. 
- - - Parameters - ---------- - - dataset : torch.utils.data.dataset.Dataset - An instance of torch.utils.data.dataset.Dataset - ConcatDataset are supported - - - Returns - ------- - - positive_weights : :py:class:`torch.Tensor` - the positive weight of each class in the dataset given as input - """ - targets = [] - - if isinstance(dataset, torch.utils.data.ConcatDataset): - for ds in dataset.datasets: - for s in ds._samples: - targets.append(s["label"]) - - else: - for s in dataset._samples: - targets.append(s["label"]) - - targets = torch.tensor(targets) - - # Binary labels - if len(list(targets.shape)) == 1: - class_sample_count = [ - float((targets == t).sum().item()) - for t in torch.unique(targets, sorted=True) - ] - - # Divide negatives by positives - positive_weights = torch.tensor( - [class_sample_count[0] / class_sample_count[1]] - ).reshape(-1) - - # Multiclass labels - else: - class_sample_count = torch.sum(targets, dim=0) - negative_class_sample_count = ( - torch.full((targets.size()[1],), float(targets.size()[0])) - - class_sample_count - ) - - positive_weights = negative_class_sample_count / ( - class_sample_count + negative_class_sample_count - ) - - return positive_weights - - -def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid): - from torch.nn import BCEWithLogitsLoss - - datamodule.prepare_data() - datamodule.setup(stage="fit") - - train_dataset = datamodule.train_dataset - validation_dataset = datamodule.validation_dataset - - # Redefine a weighted criterion if possible - if isinstance(criterion, torch.nn.BCEWithLogitsLoss): - positive_weights = get_positive_weights(train_dataset) - model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights) - else: - logger.warning("Weighted criterion not supported") - - if validation_dataset is not None: - # Redefine a weighted valid criterion if possible - if ( - isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss) - or criterion_valid is None - ): - positive_weights = get_positive_weights(validation_dataset) - model.hparams.criterion_valid = BCEWithLogitsLoss( - pos_weight=positive_weights - ) - else: - logger.warning("Weighted valid criterion not supported") - - -def normalize_data(normalization, model, datamodule): - from torch.utils.data import DataLoader - - datamodule.prepare_data() - datamodule.setup(stage="fit") - - train_dataset = datamodule.train_dataset - - # Create z-normalization model layer if needed - if normalization == "imagenet": - model.normalizer.set_mean_std( - [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] - ) - logger.info("Z-normalization with ImageNet mean and std") - elif normalization == "current": - # Compute mean/std of current train subset - temp_dl = DataLoader( - dataset=train_dataset, batch_size=len(train_dataset) - ) - - data = next(iter(temp_dl)) - mean = data[1].mean(dim=[0, 2, 3]) - std = data[1].std(dim=[0, 2, 3]) - - model.normalizer.set_mean_std(mean, std) - - # Format mean and std for logging - mean = str( - [ - round(x, 3) - for x in ((mean * 10**3).round() / (10**3)).tolist() - ] - ) - std = str( - [ - round(x, 3) - for x in ((std * 10**3).round() / (10**3)).tolist() - ] - ) - logger.info(f"Z-normalization with mean {mean} and std {std}") diff --git a/src/ptbench/data/image_utils.py b/src/ptbench/data/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac31b9ce7fbce85fb688b394c99d591b83049f7f --- /dev/null +++ b/src/ptbench/data/image_utils.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research 
Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+
+"""Data loading code."""
+
+import pathlib
+
+import numpy
+import PIL.Image
+
+
+class SingleAutoLevel16to8:
+    """Converts a 16-bit image to 8-bit representation using "auto-level".
+
+    This transform assumes that the input image is gray-scaled.
+
+    To auto-level, we calculate the maximum and the minimum of the image, and
+    consider such a range should be mapped to the [0,255] range of the
+    destination image.
+    """
+
+    def __call__(self, img):
+        imin, imax = img.getextrema()
+        irange = imax - imin
+        return PIL.Image.fromarray(
+            numpy.round(
+                255.0 * (numpy.array(img).astype(float) - imin) / irange
+            ).astype("uint8"),
+        ).convert("L")
+
+
+class RemoveBlackBorders:
+    """Remove black borders of CXR."""
+
+    def __init__(self, threshold=0):
+        self.threshold = threshold
+
+    def __call__(self, img):
+        img = numpy.asarray(img)
+        mask = numpy.asarray(img) > self.threshold
+        return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
+
+
+def load_pil(path: str | pathlib.Path) -> PIL.Image.Image:
+    """Loads a sample image from disk.
+
+    Parameters
+    ----------
+
+    path
+        The full path leading to the image to be loaded
+
+
+    Returns
+    -------
+
+    image
+        A PIL image
+    """
+    return PIL.Image.open(path)
+
+
+def load_pil_baw(path: str | pathlib.Path) -> PIL.Image.Image:
+    """Loads a sample image from disk.
+
+    Parameters
+    ----------
+
+    path
+        The full path leading to the image to be loaded
+
+
+    Returns
+    -------
+
+    image
+        A PIL image in grayscale mode
+    """
+    return load_pil(path).convert("L")
+
+
+def load_pil_rgb(path: str | pathlib.Path) -> PIL.Image.Image:
+    """Loads a sample image from disk.
+
+    Parameters
+    ----------
+
+    path
+        The full path leading to the image to be loaded
+
+
+    Returns
+    -------
+
+    image
+        A PIL image in RGB mode
+    """
+    return load_pil(path).convert("RGB")
diff --git a/src/ptbench/data/loader.py b/src/ptbench/data/loader.py
deleted file mode 100644
index d6e86ed06dd04abb230830ebc8b13e2ea9720cfa..0000000000000000000000000000000000000000
--- a/src/ptbench/data/loader.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-
-"""Data loading code."""
-
-import PIL.Image
-
-
-def load_pil(path):
-    """Loads a sample data.
-
-    Parameters
-    ----------
-
-    path : str
-        The full path leading to the image to be loaded
-
-
-    Returns
-    -------
-
-    image : PIL.Image.Image
-        A PIL image
-    """
-    return PIL.Image.open(path)
-
-
-def load_pil_baw(path):
-    """Loads a sample data.
-
-    Parameters
-    ----------
-
-    path : str
-        The full path leading to the image to be loaded
-
-
-    Returns
-    -------
-
-    image : PIL.Image.Image
-        A PIL image in grayscale mode
-    """
-    return load_pil(path).convert("L")
-
-
-def load_pil_rgb(path):
-    """Loads a sample data.
-
-    Parameters
-    ----------
-
-    path : str
-        The full path leading to the image to be loaded
-
-
-    Returns
-    -------
-
-    image : PIL.Image.Image
-        A PIL image in RGB mode
-    """
-    return load_pil(path).convert("RGB")
diff --git a/src/ptbench/data/padchest/__init__.py b/src/ptbench/data/padchest/__init__.py
index af1dd3ec9f8cbaa1b3e32f5195363592602d94ac..f6f39a9ebe98eb68a88374b8d83a70447a4ad995 100644
--- a/src/ptbench/data/padchest/__init__.py
+++ b/src/ptbench/data/padchest/__init__.py
@@ -264,7 +264,7 @@ json_dataset = JSONDataset(
 def _maker(protocol, resize_size=512, cc_size=512, RGB=True):
     import torchvision.transforms as transforms
 
-    from ..transforms import SingleAutoLevel16to8
+    from ..image_utils import SingleAutoLevel16to8
 
     post_transforms = []
     if not RGB:
diff --git a/src/ptbench/data/shenzhen/__init__.py b/src/ptbench/data/shenzhen/__init__.py
index 54eb632684a83e16ac28481f8499137f594b662a..1645962e8cc00443399dd60b88f017c71824e086 100644
--- a/src/ptbench/data/shenzhen/__init__.py
+++ b/src/ptbench/data/shenzhen/__init__.py
@@ -13,22 +13,11 @@ the daily routine using Philips DR Digital Diagnose systems.
 * Reference: [MONTGOMERY-SHENZHEN-2014]_
 * Original resolution (height x width or width x height): 3000 x 3000 or less
 * Split reference: none
-* Protocol ``default``:
-
 * Training samples: 64% of TB and healthy CXR (including labels)
 * Validation samples: 16% of TB and healthy CXR (including labels)
 * Test samples: 20% of TB and healthy CXR (including labels)
 """
 
 import importlib.resources
-import os
-
-from clapper.logging import setup
-
-from ...utils.rc import load_rc
-from ..loader import load_pil_baw
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
 
 _protocols = [
     importlib.resources.files(__name__).joinpath("default.json.bz2"),
@@ -43,11 +32,3 @@ _protocols = [
     importlib.resources.files(__name__).joinpath("fold_8.json.bz2"),
     importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
 ]
-
-_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
-
-
-def _raw_data_loader(img_path):
-    raw_data = load_pil_baw(os.path.join(_datadir, img_path))
-
-    return raw_data
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index b5fb23f59277f1f6511c6707929856cc2357be2f..bfe93f44faaa9df235f357c1cc3a927412f4a011 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -2,27 +2,46 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Shenzhen dataset for TB detection (default protocol)
+"""Shenzhen datamodule for computer-aided diagnosis (default protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.shenzhen` for dataset details
-"""
+See :py:mod:`ptbench.data.shenzhen` for more database details.
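+
+Usage sketch (assuming the raw dataset location is configured, e.g. via the
+``datadir.shenzhen`` setting; the explicit ``setup()`` call is normally issued
+by the lightning trainer, and is shown here only for illustration):
+
+.. code-block:: python
+
+   from ptbench.data.shenzhen.default import datamodule
+
+   datamodule.setup("test")
+   test_loader = datamodule.test_dataloader()["test"]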
+ +This configuration: + +* Raw data input (on disk): + + * PNG images (black and white, encoded as color images) + * Variable width and height: -from clapper.logging import setup + * widths: from 1130 to 3001 pixels + * heights: from 948 to 3001 pixels -from ..transforms import ElasticDeformation -from .utils import ShenzhenDataModule +* Output image: -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + * Transforms: + + * Load raw PNG with :py:mod:`PIL` + * Remove black borders + * Torch resizing(512px, 512px) + * Torch center cropping (512px, 512px) + + * Final specifications: + + * Fixed resolution: 512x512 pixels + * Color RGB encoding +""" -protocol_name = "default" +import importlib.resources -augmentation_transforms = [ElasticDeformation(p=0.8)] +from ..datamodule import CachingDataModule +from ..split import JSONDatabaseSplit +from .loader import RawDataLoader -datamodule = ShenzhenDataModule( - protocol="default", - model_transforms=[], - augmentation_transforms=augmentation_transforms, +datamodule = CachingDataModule( + database_split=JSONDatabaseSplit( + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath( + "default.json.bz2" + ) + ), + raw_data_loader=RawDataLoader(), ) diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..e983fe7028d858118efb9ba41e560aee8e95845a --- /dev/null +++ b/src/ptbench/data/shenzhen/loader.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Shenzhen dataset for computer-aided diagnosis. + +The standard digital image database for Tuberculosis is created by the National +Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s +Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from +out-patient clinics, and were captured as part of the daily routine using +Philips DR Digital Diagnose systems. + +* Reference: [MONTGOMERY-SHENZHEN-2014]_ +* Original resolution (height x width or width x height): 3000 x 3000 or less +* Split reference: none +* Protocol ``default``: + + * Training samples: 64% of TB and healthy CXR (including labels) + * Validation samples: 16% of TB and healthy CXR (including labels) + * Test samples: 20% of TB and healthy CXR (including labels) +""" + +import os + +import torchvision.transforms + +from ...utils.rc import load_rc +from ..image_utils import RemoveBlackBorders, load_pil_baw +from ..typing import RawDataLoader as _BaseRawDataLoader +from ..typing import Sample + + +class RawDataLoader(_BaseRawDataLoader): + """A specialized raw-data-loader for the Shenzen dataset. + + Attributes + ---------- + + datadir + This variable contains the base directory where the database raw data + is stored. + + transform + Transforms that are always applied to the loaded raw images. + """ + + datadir: str + transform: torchvision.transforms.Compose + + def __init__(self): + self.datadir = load_rc().get( + "datadir.shenzhen", os.path.realpath(os.curdir) + ) + + self.transform = torchvision.transforms.Compose( + [ + RemoveBlackBorders(), + torchvision.transforms.Resize(512), + torchvision.transforms.CenterCrop(512), + torchvision.transforms.ToTensor(), + ] + ) + + def sample(self, sample: tuple[str, int]) -> Sample: + """Loads a single image sample from the disk. 
+ + Parameters + ---------- + + sample: + A tuple containing the path suffix, within the dataset root folder, + where to find the image to be loaded, and an integer, representing the + sample label. + + + Returns + ------- + + sample + The sample representation + """ + tensor = self.transform( + load_pil_baw(os.path.join(self.datadir, sample[0])) + ) + return tensor, dict(label=sample[1]) # type: ignore[arg-type] + + def label(self, sample: tuple[str, int]) -> int: + """Loads a single image sample label from the disk. + + Parameters + ---------- + + sample: + A tuple containing the path suffix, within the dataset root folder, + where to find the image to be loaded, and an integer, representing the + sample label. + + + Returns + ------- + + label + The integer label associated with the sample + """ + return sample[1] diff --git a/src/ptbench/data/shenzhen/rgb.py b/src/ptbench/data/shenzhen/rgb.py index 2d93c07edb4c0824caa8149737ec42783966ad4c..f45f601d6e84a04e4610319d1f27631ff32ee69c 100644 --- a/src/ptbench/data/shenzhen/rgb.py +++ b/src/ptbench/data/shenzhen/rgb.py @@ -2,81 +2,40 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen dataset for TB detection (cross validation fold 0, RGB) +"""Shenzhen datamodule for computer-aided diagnosis (default protocol) -* Split reference: first 80% of TB and healthy CXR for "train", rest for "test" -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.shenzhen` for dataset details -""" - -from clapper.logging import setup - -from ....data import return_subsets -from ....data.base_datamodule import BaseDataModule -from ....data.dataset import JSONProtocol -from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols +See :py:mod:`ptbench.data.shenzhen` for dataset details. -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - cache_samples=False, - multiproc_kwargs=None, - data_transforms=[], - model_transforms=[], - train_transforms=[], - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - self.cache_samples = cache_samples - self.has_setup_fit = False +This configuration: +* raw data (default): :py:obj:`ptbench.data.shenzhen._tranforms` +* augmentations: elastic deformation (probability = 80%) +* output image resolution: 512x512 pixels +""" - self.data_transforms = data_transforms - self.model_transforms = model_transforms - self.train_transforms = train_transforms +import importlib.resources - """[ - transforms.ToPILImage(), - transforms.Lambda(lambda x: x.convert("RGB")), - transforms.ToTensor(), - ]""" +from torchvision import transforms - def setup(self, stage: str): - if self.cache_samples: - logger.info( - "Argument cache_samples set to True. Samples will be loaded in memory." - ) - samples_loader = _cached_loader - else: - logger.info( - "Argument cache_samples set to False. Samples will be loaded at runtime." 
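To make the loader's contract concrete, this is roughly how it is meant to be driven. A minimal sketch, assuming ``datadir.shenzhen`` points at an installed copy of the database and that the path suffix below exists there (both are assumptions of this example):

.. code-block:: python

    from ptbench.data.shenzhen.loader import RawDataLoader

    loader = RawDataLoader()
    key = ("CXR_png/CHNCXR_0001_0.png", 0)  # hypothetical (path-suffix, label)

    tensor, metadata = loader.sample(key)
    assert tensor.shape == (1, 512, 512)  # grayscale, resized + center-cropped
    assert metadata["label"] == 0

    assert loader.label(key) == 0  # label() never decodes the image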
-            )
-            samples_loader = _delayed_loader
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-        self.json_protocol = JSONProtocol(
-            protocols=_protocols,
-            fieldnames=("data", "label"),
-            loader=samples_loader,
-            post_transforms=self.post_transforms,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "default.json.bz2"
         )
-
-        if not self.has_setup_fit and stage == "fit":
-            (
-                self.train_dataset,
-                self.validation_dataset,
-                self.extra_validation_datasets,
-            ) = return_subsets(self.json_protocol, "default", stage)
-            self.has_setup_fit = True
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+    cache_samples=False,
+    # train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
+    model_transforms=[
+        transforms.ToPILImage(),
+        transforms.Lambda(lambda x: x.convert("RGB")),
+        transforms.ToTensor(),
+    ],
+    # batch_size = 1,
+    # batch_chunk_count = 1,
+    # drop_incomplete_batch = False,
+    # parallel = -1,
+)
diff --git a/src/ptbench/data/shenzhen/utils.py b/src/ptbench/data/shenzhen/utils.py
deleted file mode 100644
index 1521b674212feec942e73df081e7b20c19d89e29..0000000000000000000000000000000000000000
--- a/src/ptbench/data/shenzhen/utils.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Shenzhen dataset for computer-aided diagnosis.
-
-The standard digital image database for Tuberculosis is created by the
-National Library of Medicine, Maryland, USA in collaboration with Shenzhen
-No.3 People’s Hospital, Guangdong Medical College, Shenzhen, China.
-The Chest X-rays are from out-patient clinics, and were captured as part of
-the daily routine using Philips DR Digital Diagnose systems.
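The same wiring extends to the cross-validation folds bundled with the package. A minimal sketch, reusing the ``CachingDataModule``/``JSONDatabaseSplit`` combination exactly as above, with ``fold_0.json.bz2`` (one of the files listed in ``_protocols``):

.. code-block:: python

    import importlib.resources

    from ptbench.data.datamodule import CachingDataModule
    from ptbench.data.shenzhen.loader import RawDataLoader
    from ptbench.data.split import JSONDatabaseSplit

    datamodule = CachingDataModule(
        database_split=JSONDatabaseSplit(
            importlib.resources.files("ptbench.data.shenzhen").joinpath(
                "fold_0.json.bz2"
            )
        ),
        raw_data_loader=RawDataLoader(),
    )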
- -* Reference: [MONTGOMERY-SHENZHEN-2014]_ -* Original resolution (height x width or width x height): 3000 x 3000 or less -* Split reference: none -* Protocol ``default``: - - * Training samples: 64% of TB and healthy CXR (including labels) - * Validation samples: 16% of TB and healthy CXR (including labels) - * Test samples: 20% of TB and healthy CXR (including labels) -""" -from clapper.logging import setup -from torchvision import transforms - -from ..base_datamodule import BaseDataModule -from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset -from ..shenzhen import _protocols, _raw_data_loader -from ..transforms import RemoveBlackBorders - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class ShenzhenDataModule(BaseDataModule): - def __init__( - self, - protocol="default", - model_transforms=[], - augmentation_transforms=[], - batch_size=1, - batch_chunk_count=1, - drop_incomplete_batch=False, - cache_samples=False, - parallel=-1, - ): - super().__init__( - batch_size=batch_size, - drop_incomplete_batch=drop_incomplete_batch, - batch_chunk_count=batch_chunk_count, - parallel=parallel, - ) - - self._cache_samples = cache_samples - self._has_setup_fit = False - self._has_setup_predict = False - self._protocol = protocol - - self.raw_data_transforms = [ - RemoveBlackBorders(), - transforms.Resize(512), - transforms.CenterCrop(512), - transforms.ToTensor(), - ] - - self.model_transforms = model_transforms - - self.augmentation_transforms = augmentation_transforms - - def setup(self, stage: str): - json_protocol = JSONProtocol( - protocols=_protocols, - fieldnames=("data", "label"), - ) - - if self._cache_samples: - dataset = CachedDataset - else: - dataset = RuntimeDataset - - if not self._has_setup_fit and stage == "fit": - self.train_dataset = dataset( - json_protocol, - self._protocol, - "train", - _raw_data_loader, - self._build_transforms(is_train=True), - ) - - self.validation_dataset = dataset( - json_protocol, - self._protocol, - "validation", - _raw_data_loader, - self._build_transforms(is_train=False), - ) - - self._has_setup_fit = True - - if not self._has_setup_predict and stage == "predict": - self.train_dataset = dataset( - json_protocol, - self._protocol, - "train", - _raw_data_loader, - self._build_transforms(is_train=False), - ) - self.validation_dataset = dataset( - json_protocol, - self._protocol, - "validation", - _raw_data_loader, - self._build_transforms(is_train=False), - ) - - self._has_setup_predict = True diff --git a/src/ptbench/data/split.py b/src/ptbench/data/split.py new file mode 100644 index 0000000000000000000000000000000000000000..78fe7e33886cd47265d0a7f217bcb55f799c9fa4 --- /dev/null +++ b/src/ptbench/data/split.py @@ -0,0 +1,252 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import csv +import importlib.abc +import json +import logging +import pathlib +import typing + +import torch + +from .typing import DatabaseSplit, RawDataLoader + +logger = logging.getLogger(__name__) + + +class JSONDatabaseSplit(DatabaseSplit): + """Defines a loader that understands a database split (train, test, etc) in + JSON format. + + To create a new database split, you need to provide a JSON formatted + dictionary in a file, with contents similar to the following: + + .. 
code-block:: json
+
+        {
+            "subset1": [
+                [
+                    "sample1-data1",
+                    "sample1-data2",
+                    "sample1-data3"
+                ],
+                [
+                    "sample2-data1",
+                    "sample2-data2",
+                    "sample2-data3"
+                ]
+            ],
+            "subset2": [
+                [
+                    "sample42-data1",
+                    "sample42-data2",
+                    "sample42-data3"
+                ]
+            ]
+        }
+
+    Your database split may contain any number of subsets (dictionary keys).
+    For simplicity, we recommend that all sample entries are formatted
+    similarly, so that raw-data-loading is simplified. Use the function
+    :py:func:`check_database_split_loading` to test raw data loading and to
+    fine-tune the dataset split, or its loading.
+
+    Objects of this class behave like a dictionary in which keys are subset
+    names in the split, and values represent samples data and meta-data.
+
+
+    Parameters
+    ----------
+
+    path
+        Absolute path to a JSON formatted file containing the database split
+        to be recognized by this object.
+    """
+
+    def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable):
+        if isinstance(path, str):
+            path = pathlib.Path(path)
+        self.path = path
+        self.subsets = self._load_split_from_disk()
+
+    def _load_split_from_disk(self) -> DatabaseSplit:
+        """Loads all subsets in a split from its file system representation.
+
+        This method will load the JSON information for the current split and
+        return all of its subsets.
+
+
+        Returns
+        -------
+
+        subsets : dict
+            A dictionary mapping subset names to lists of JSON objects
+        """
+
+        if str(self.path).endswith(".bz2"):
+            logger.debug(f"Loading database split from {str(self.path)}...")
+            import bz2
+
+            with bz2.open(self.path) as f:
+                return json.load(f)
+        else:
+            with self.path.open() as f:
+                return json.load(f)
+
+    def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
+        """Accesses subset ``key`` from this split."""
+        return self.subsets[key]
+
+    def __iter__(self):
+        """Iterates over the subsets."""
+        return iter(self.subsets)
+
+    def __len__(self) -> int:
+        """How many subsets we currently have."""
+        return len(self.subsets)
+
+
+class CSVDatabaseSplit(DatabaseSplit):
+    """Defines a loader that understands a database split (train, test, etc)
+    in CSV format.
+
+    To create a new database split, you need to provide one or more CSV
+    formatted files, each representing a subset of this split, containing the
+    sample data (one per row). Example:
+
+    Inside the directory ``my-split/``, one can find the files ``train.csv``,
+    ``validation.csv``, and ``test.csv``. Each file has a structure similar to
+    the following:
+
+    .. code-block:: text
+
+        sample1-value1,sample1-value2,sample1-value3
+        sample2-value1,sample2-value2,sample2-value3
+        ...
+
+    Each file in the provided directory defines the subset name on the split.
+    So, the file ``train.csv`` will contain the data from the ``train``
+    subset, and so on.
+
+    Objects of this class behave like a dictionary in which keys are subset
+    names in the split, and values represent samples data and meta-data.
+
+
+    Parameters
+    ----------
+
+    directory
+        Absolute path to a directory containing the database split laid out
+        as a set of CSV files, one per subset.
+ """ + + def __init__( + self, directory: pathlib.Path | str | importlib.abc.Traversable + ): + if isinstance(directory, str): + directory = pathlib.Path(directory) + assert ( + directory.is_dir() + ), f"`{str(directory)}` is not a valid directory" + self.directory = directory + self.subsets = self._load_split_from_disk() + + def _load_split_from_disk(self) -> DatabaseSplit: + """Loads all subsets in a split from its file system representation. + + This method will load CSV information for the current split and return all + subsets of the given split after converting each entry through the + loader function. + + + Returns + ------- + + subsets : dict + A dictionary mapping subset names to lists of JSON objects + """ + + retval: DatabaseSplit = {} + for subset in self.directory.iterdir(): + if str(subset).endswith(".csv.bz2"): + logger.debug(f"Loading database split from {subset}...") + with __import__("bz2").open(subset) as f: + reader = csv.reader(f) + retval[subset.name[: -len(".csv.bz2")]] = [ + k for k in reader + ] + elif str(subset).endswith(".csv"): + with subset.open() as f: + reader = csv.reader(f) + retval[subset.name[: -len(".csv")]] = [k for k in reader] + else: + logger.debug( + f"Ignoring file {subset} in CSVDatabaseSplit readout" + ) + return retval + + def __getitem__(self, key: str) -> typing.Sequence[typing.Any]: + """Accesses subset ``key`` from this split.""" + return self.subsets[key] + + def __iter__(self): + """Iterates over the subsets.""" + return iter(self.subsets) + + def __len__(self) -> int: + """How many subsets we currently have.""" + return len(self.subsets) + + +def check_database_split_loading( + database_split: DatabaseSplit, + loader: RawDataLoader, + limit: int = 0, +) -> int: + """For each subset in the split, check if all data can be correctly loaded + using the provided loader function. + + This function will return the number of errors loading samples, and will + log more detailed information to the logging stream. + + + Parameters + ---------- + + database_split + A mapping that, contains the database split. Each key represents the + name of a subset in the split. Each value is a (potentially complex) + object that represents a single sample. + + loader + A loader object that knows how to handle full-samples or just labels. + + limit + Maximum number of samples to check (in each split/subset + combination) in this dataset. If set to zero, then check + everything. + + + Returns + ------- + + errors + Number of errors found + """ + logger.info( + "Checking if can load all samples in all subsets of this split..." + ) + errors = 0 + for subset in database_split.keys(): + samples = subset if not limit else subset[:limit] + for pos, sample in enumerate(samples): + try: + data, _ = loader.sample(sample) + assert isinstance(data, torch.Tensor) + except Exception as e: + logger.info( + f"Found error loading entry {pos} in subset `{subset}`: {e}" + ) + errors += 1 + return errors diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py index 6d3c17d243a4987f889bae8da0bb6a21ce73459b..ad516e194570dca2b5794959f816a0e717733a9d 100644 --- a/src/ptbench/data/transforms.py +++ b/src/ptbench/data/transforms.py @@ -22,39 +22,6 @@ from scipy.ndimage import gaussian_filter, map_coordinates from torchvision import transforms -class SingleAutoLevel16to8: - """Converts a 16-bit image to 8-bit representation using "auto-level". - - This transform assumes that the input image is gray-scaled. 
- - To auto-level, we calculate the maximum and the minimum of the - image, and - consider such a range should be mapped to the [0,255] range of the - destination image. - """ - - def __call__(self, img): - imin, imax = img.getextrema() - irange = imax - imin - return PIL.Image.fromarray( - numpy.round( - 255.0 * (numpy.array(img).astype(float) - imin) / irange - ).astype("uint8"), - ).convert("L") - - -class RemoveBlackBorders: - """Remove black borders of CXR.""" - - def __init__(self, threshold=0): - self.threshold = threshold - - def __call__(self, img): - img = numpy.asarray(img) - mask = numpy.asarray(img) > self.threshold - return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) - - class ElasticDeformation: """Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_. @@ -68,7 +35,7 @@ class ElasticDeformation: spline_order=1, mode="nearest", random_state=numpy.random, - p=1, + p=1.0, ): self.alpha = alpha self.sigma = sigma @@ -79,13 +46,15 @@ class ElasticDeformation: def __call__(self, img): if random.random() < self.p: - img = transforms.ToPILImage()(img) + assert img.ndim == 3 + # Input tensor is of shape C x H x W + # If the tensor only contains one channel, this conversion results in H x W. + # With 3 channels, we get H x W x C + img = transforms.ToPILImage()(img) img = numpy.asarray(img) - assert img.ndim == 2 - - shape = img.shape + shape = img.shape[:2] dx = ( gaussian_filter( @@ -114,9 +83,22 @@ class ElasticDeformation: numpy.reshape(y + dy, (-1, 1)), ] result = numpy.empty_like(img) - result[:, :] = map_coordinates( - img[:, :], indices, order=self.spline_order, mode=self.mode - ).reshape(shape) + + if img.ndim == 2: + result[:, :] = map_coordinates( + img[:, :], indices, order=self.spline_order, mode=self.mode + ).reshape(shape) + + else: + for i in range(img.shape[2]): + result[:, :, i] = map_coordinates( + img[:, :, i], + indices, + order=self.spline_order, + mode=self.mode, + ).reshape(shape) + return transforms.ToTensor()(PIL.Image.fromarray(result)) + else: return img diff --git a/src/ptbench/data/typing.py b/src/ptbench/data/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..344c1294df6777f88acdde23dec51d40fc51e31e --- /dev/null +++ b/src/ptbench/data/typing.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Defines most common types used in code.""" + +import typing + +import torch +import torch.utils.data + +Sample = tuple[torch.Tensor, typing.Mapping[str, typing.Any]] +"""Definition of a sample. + +First parameter + The actual data that is input to the model + +Second parameter + A dictionary containing a named set of meta-data. One the most common is + the ``label`` entry. +""" + + +class RawDataLoader: + """A loader object can load samples and labels from storage.""" + + def sample(self, _: typing.Any) -> Sample: + """Loads whole samples from media.""" + raise NotImplementedError("You must implement the `sample()` method") + + def label(self, k: typing.Any) -> int: + """Loads only sample label from media. + + If you do not override this implementation, then, by default, + this method will call :py:meth:`sample` to load the whole sample + and extract the label. + """ + return self.sample(k)[1]["label"] + + +Transform = typing.Callable[[torch.Tensor], torch.Tensor] +"""A callable, that transforms tensors into (other) tensors. + +Typically used in data-processing pipelines inside pytorch. 
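As a sanity check of the tensor-in/tensor-out contract of ``ElasticDeformation`` above, a minimal sketch (shapes only; the deformation itself is random):

.. code-block:: python

    import torch

    from ptbench.data.transforms import ElasticDeformation

    augment = ElasticDeformation(p=1.0)  # p=1.0: always deform, for demo

    image = torch.rand(3, 64, 64)  # C x H x W, as produced by ToTensor()
    deformed = augment(image)
    assert deformed.shape == image.shape  # shape is preserved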
+""" + +TransformSequence = typing.Sequence[Transform] +"""A sequence of transforms.""" + +DatabaseSplit = dict[str, typing.Sequence[typing.Any]] +"""The definition of a database script. + +A database script maps subset names to sequences of objects that, +through RawDataLoader's eventually become Samples in the processing +pipeline. +""" + + +class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized): + """Our own definition of a pytorch Dataset, with interesting properties. + + We iterate over Sample objects in this case. Our datasets always + provide a dunder len method. + """ + + def labels(self) -> list[int]: + """Returns the integer labels for all samples in the dataset.""" + raise NotImplementedError("You must implement the `labels()` method") + + +DataLoader = torch.utils.data.DataLoader[Sample] +"""Our own augmentation definition of a pytorch DataLoader. + +We iterate over Sample objects in this case. +""" diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py index d0ac43f98e21b8ce6803797d6a1fde38c6302660..350140a8516ddef43e89323e65b746ca7c479182 100644 --- a/src/ptbench/engine/callbacks.py +++ b/src/ptbench/engine/callbacks.py @@ -1,155 +1,403 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + import csv +import logging import os +import pathlib import time +import typing -from collections import defaultdict +import lightning.pytorch +import lightning.pytorch.callbacks +import torch -import numpy +from ..utils.resources import ResourceMonitor -from lightning.pytorch import Callback -from lightning.pytorch.callbacks import BasePredictionWriter +logger = logging.getLogger(__name__) -# This ensures CSVLogger logs training and evaluation metrics on the same line -# CSVLogger only accepts numerical values, not strings -class LoggingCallback(Callback): - """Lightning callback to log various training metrics and device - information.""" +class LoggingCallback(lightning.pytorch.Callback): + """Callback to log various training metrics and device information. - def __init__(self, resource_monitor): - super().__init__() - self.training_loss = [] - self.validation_loss = [] - self.extra_validation_loss = defaultdict(list) - self.start_training_time = 0 - self.start_epoch_time = 0 + It ensures CSVLogger logs training and evaluation metrics on the same line + Note that a CSVLogger only accepts numerical values, and not strings. - self.resource_monitor = resource_monitor - self.max_queue_retries = 2 - def on_train_start(self, trainer, pl_module): - self.start_training_time = time.time() + Parameters + ---------- - def on_train_epoch_start(self, trainer, pl_module): - self.start_epoch_time = time.time() + resource_monitor + A monitor that watches resource usage (CPU/GPU) in a separate process + and totally asynchronously with the code execution. 
+ """ - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - self.training_loss.append(outputs["loss"].item()) + def __init__(self, resource_monitor: ResourceMonitor): + super().__init__() - def on_validation_batch_end( - self, trainer, pl_module, outputs, batch, batch_idx + # lists of number of samples/batch and average losses + # - we use this later to compute overall epoch losses + self._training_epoch_loss: tuple[list[int], list[float]] = ([], []) + self._validation_epoch_loss: dict[ + int, tuple[list[int], list[float]] + ] = {} + + # timers + self._start_training_time = 0.0 + self._start_training_epoch_time = 0.0 + self._start_validation_epoch_time = 0.0 + + # log accumulators for a single flush at each training cycle + self._to_log: dict[str, float] = {} + + # helpers for CPU and GPU utilisation + self._resource_monitor = resource_monitor + self._max_queue_retries = 2 + + def on_train_start( + self, + trainer: lightning.pytorch.Trainer, + pl_module: lightning.pytorch.LightningModule, ): - self.validation_loss.append(outputs["validation_loss"].item()) + """Callback to be executed **before** the whole training starts. - if len(outputs) > 1: - extra_validation_keys = outputs.keys().remove("validation_loss") - for extra_validation_loss_key in extra_validation_keys: - self.extra_validation_loss[extra_validation_loss_key].append( - outputs[extra_validation_loss_key] - ) + This method is executed whenever you *start* training a module. - def on_validation_epoch_end(self, trainer, pl_module): - self.resource_monitor.trigger_summary() - self.epoch_time = time.time() - self.start_epoch_time - eta_seconds = self.epoch_time * ( - trainer.max_epochs - trainer.current_epoch - ) - current_time = time.time() - self.start_training_time + Parameters + --------- - def _compute_batch_loss(losses, num_chunks): - # When accumulating gradients, partial losses need to be summed per batch before averaging - if num_chunks != 1: - # The loss we get is scaled by the number of accumulation steps - losses = numpy.multiply(losses, num_chunks) + trainer + The Lightning trainer object - if len(losses) % num_chunks > 0: - num_splits = (len(losses) // num_chunks) + 1 - else: - num_splits = len(losses) // num_chunks + pl_module + The lightning module that is being trained + """ + self._start_training_time = time.time() - batched_losses = numpy.array_split(losses, num_splits) + def on_train_epoch_start( + self, + trainer: lightning.pytorch.Trainer, + pl_module: lightning.pytorch.LightningModule, + ) -> None: + """Callback to be executed **before** every training batch starts. - summed_batch_losses = [] + This method is executed whenever a training batch starts. Presumably, + batches happen as often as possible. You want to make this code very + fast. Do not log things to the terminal or the such, or do complicated + (lengthy) calculations. - for b in batched_losses: - summed_batch_losses.append(numpy.average(b)) + .. warning:: - return summed_batch_losses + This is executed **while** you are training. Be very succint or + face the consequences of slow training! 
-
-            # No gradient accumulation, we already have the batch losses
-            else:
-                return losses
-        # Do not log during sanity check as results are not relevant
-        if not trainer.sanity_checking:
-            # We get partial loses when using gradient accumulation
-            self.training_loss = _compute_batch_loss(
-                self.training_loss, trainer.accumulate_grad_batches
-            )
-            self.validation_loss = _compute_batch_loss(
-                self.validation_loss, trainer.accumulate_grad_batches
-            )
+
+        Parameters
+        ----------
+
+        trainer
+            The Lightning trainer object
 
-            self.log("total_time", current_time)
-            self.log("eta", eta_seconds)
-            self.log("loss", numpy.average(self.training_loss))
-            self.log(
-                "learning_rate", pl_module.hparams["optimizer_configs"]["lr"]
-            )
-            self.log("validation_loss", numpy.sum(self.validation_loss))
-
-            if len(self.extra_validation_loss) > 0:
-                for (
-                    extra_valid_loss_key,
-                    extra_valid_loss_values,
-                ) in self.extra_validation_loss.items:
-                    self.log(
-                        extra_valid_loss_key, numpy.sum(extra_valid_loss_values)
-                    )
+        pl_module
+            The lightning module that is being trained
+        """
+        self._start_training_epoch_time = time.time()
+        self._training_epoch_loss = ([], [])
+
+    def on_train_epoch_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+    ):
+        """Callback to be executed **after** every training epoch ends.
+
+        This method is executed whenever a training epoch ends. Keep this
+        code relatively fast to avoid significant runtime slow-downs.
+
+
+        Parameters
+        ----------
+
+        trainer
+            The Lightning trainer object
+
+        pl_module
+            The lightning module that is being trained
+        """
+
+        # summarizes resource usage since the last checkpoint
+        # clears internal buffers and starts accumulating again.
+        self._resource_monitor.checkpoint()
+
+        # evaluates this training epoch total time, and logs it
+        epoch_time = time.time() - self._start_training_epoch_time
+
+        # Computes the overall training loss considering batches and sizes:
+        # each per-batch average loss is weighted by its batch size, so the
+        # result is the average over all samples seen in this epoch
+        # (accumulate_grad_batches is disregarded on purpose).
+        sizes = torch.tensor(self._training_epoch_loss[0])
+        losses = torch.tensor(self._training_epoch_loss[1])
+        self._to_log["train_loss"] = (
+            (sizes * losses).sum() / sizes.sum()
+        ).item()
+
+        self._to_log["train_epoch_time"] = epoch_time
+        self._to_log["learning_rate"] = pl_module.optimizers().defaults["lr"]
+
+        metrics = self._resource_monitor.data
+        if metrics is not None:
+            for metric_name, metric_value in metrics.items():
+                self._to_log[f"train_{metric_name}"] = float(metric_value)
+        else:
+            logger.warning(
+                "Unable to fetch monitoring information from "
+                "resource monitor. CPU/GPU utilisation will be "
+                "missing."
) - time.sleep(self.resource_monitor.interval) - if queue_retries >= self.max_queue_retries: - print( - f"Unable to fetch monitoring information from queue after {queue_retries} retries" + # if no validation dataloaders, complete cycle by the end of the + # training epoch, by logging all values to the logger + self.on_cycle_end(trainer, pl_module) + + def on_train_batch_end( + self, + trainer: lightning.pytorch.Trainer, + pl_module: lightning.pytorch.LightningModule, + outputs: typing.Mapping[str, torch.Tensor], + batch: tuple[torch.Tensor, typing.Mapping[str, torch.Tensor]], + batch_idx: int, + ) -> None: + """Callback to be executed **after** every training batch ends. + + This method is executed whenever a training batch ends. Presumably, + batches happen as often as possible. You want to make this code very + fast. Do not log things to the terminal or the such, or do complicated + (lengthy) calculations. + + .. warning:: + + This is executed **while** you are training. Be very succint or + face the consequences of slow training! + + + Parameters + ---------- + + trainer + The Lightning trainer object + + pl_module + The lightning module that is being trained + + outputs + The outputs of the module's ``training_step`` + + batch + The data that the training step received + + batch_idx + The relative number of the batch + """ + self._training_epoch_loss[0].append(batch[0].shape[0]) + self._training_epoch_loss[1].append(outputs["loss"].item()) + + def on_validation_epoch_start( + self, + trainer: lightning.pytorch.Trainer, + pl_module: lightning.pytorch.LightningModule, + ) -> None: + """Callback to be executed **before** every validation batch starts. + + This method is executed whenever a validation batch starts. Presumably, + batches happen as often as possible. You want to make this code very + fast. Do not log things to the terminal or the such, or do complicated + (lengthy) calculations. + + .. warning:: + + This is executed **while** you are training. Be very succint or + face the consequences of slow training! + + + Parameters + --------- + + trainer + The Lightning trainer object + + pl_module + The lightning module that is being trained + """ + self._start_validation_epoch_time = time.time() + self._validation_epoch_loss = {} + + def on_validation_epoch_end( + self, + trainer: lightning.pytorch.Trainer, + pl_module: lightning.pytorch.LightningModule, + ) -> None: + """Callback to be executed **after** every validation epoch ends. + + This method is executed whenever a validation epoch ends. Presumably, + epochs happen as often as possible. You want to make this code + relatively fast to avoid significative runtime slow-downs. + + + Parameters + ---------- + + trainer + The Lightning trainer object + + pl_module + The lightning module that is being trained + """ + + # summarizes resource usage since the last checkpoint + # clears internal buffers and starts accumulating again. + self._resource_monitor.checkpoint() + + epoch_time = time.time() - self._start_validation_epoch_time + self._to_log["validation_epoch_time"] = epoch_time + + metrics = self._resource_monitor.data + if metrics is not None: + for metric_name, metric_value in metrics.items(): + self._to_log[f"validation_{metric_name}"] = float(metric_value) + else: + logger.warning( + "Unable to fetch monitoring information from " + "resource monitor. CPU/GPU utilisation will be " + "missing." 
            )
-        assert self.resource_monitor.q.empty()
+        # Computes the overall validation losses considering batches and
+        # sizes: each per-batch average loss is weighted by its batch size,
+        # so the result is the average over all samples seen in this epoch
+        # (accumulate_grad_batches is disregarded on purpose).
+        for key in sorted(self._validation_epoch_loss.keys()):
+            if key == 0:
+                name = "validation_loss"
+            else:
+                name = f"validation_loss_{key}"
 
-        # Do not log during sanity check as results are not relevant
-        if not trainer.sanity_checking:
-            for metric_name, metric_value in self.resource_monitor.data:
-                self.log(metric_name, float(metric_value))
+            sizes = torch.tensor(self._validation_epoch_loss[key][0])
+            losses = torch.tensor(self._validation_epoch_loss[key][1])
+            self._to_log[name] = ((sizes * losses).sum() / sizes.sum()).item()
 
-        self.resource_monitor.data = None
+    def on_validation_batch_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+        outputs: torch.Tensor,
+        batch: tuple[torch.Tensor, typing.Mapping[str, torch.Tensor]],
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Callback to be executed **after** every validation batch ends.
+
+        This method is executed whenever a validation batch ends. Presumably,
+        batches happen as often as possible. You want to make this code very
+        fast. Do not log things to the terminal or the such, or do complicated
+        (lengthy) calculations.
 
-        self.training_loss = []
-        self.validation_loss = []
+        .. warning::
+
+           This is executed **while** you are training. Be very succinct or
+           face the consequences of slow training!
 
-class PredictionsWriter(BasePredictionWriter):
 
+        Parameters
+        ----------
+
+        trainer
+            The Lightning trainer object
+
+        pl_module
+            The lightning module that is being trained
+
+        outputs
+            The outputs of the module's ``validation_step``
+
+        batch
+            The data that the validation step received
+
+        batch_idx
+            The relative number of the batch
+
+        dataloader_idx
+            Index of the dataloader used during validation. Use this to figure
+            out which dataset was used for this validation epoch.
+        """
+        size, value = self._validation_epoch_loss.setdefault(
+            dataloader_idx, ([], [])
+        )
+        size.append(batch[0].shape[0])
+        value.append(outputs.item())
+
+    def on_cycle_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+    ) -> None:
+        """Called when the training/validation cycle has ended.
+
+        This function will log all relevant values to the various loggers. It
+        is supposed to be called by the end of the training cycle (consisting
+        of a training and validation step).
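The size-weighted reduction used for the epoch losses above can be sanity-checked in isolation; a small worked example with plain tensors:

.. code-block:: python

    import torch

    # two batches of 4 samples and one of 2, with per-batch average losses
    sizes = torch.tensor([4, 4, 2])
    losses = torch.tensor([0.5, 0.3, 0.9])

    epoch_loss = ((sizes * losses).sum() / sizes.sum()).item()
    assert abs(epoch_loss - 0.5) < 1e-6  # (2.0 + 1.2 + 1.8) / 10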
+ + + Parameters + ---------- + + trainer + The Lightning trainer object + + pl_module + The lightning module that is being trained + """ + + # collect some final time for the whole training cycle + # Note: logging should happen at on_validation_end(), but + # apparently you can't log from there + overall_cycle_time = time.time() - self._start_training_epoch_time + self._to_log["train_cycle_time"] = overall_cycle_time + self._to_log["total_time"] = time.time() - self._start_training_time + self._to_log["eta"] = overall_cycle_time * ( + trainer.max_epochs - trainer.current_epoch # type: ignore + ) + + # Do not log during sanity check as results are not relevant + if not trainer.sanity_checking: + for k in sorted(self._to_log.keys()): + pl_module.log(k, self._to_log[k]) + self._to_log = {} + + +class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter): """Lightning callback to write predictions to a file.""" - def __init__(self, output_dir, logfile_fields, write_interval): + def __init__( + self, + output_dir: str | pathlib.Path, + logfile_fields: typing.Sequence[str], + write_interval: typing.Literal["batch", "epoch", "batch_and_epoch"], + ): super().__init__(write_interval) self.output_dir = output_dir self.logfile_fields = logfile_fields def write_on_epoch_end( - self, trainer, pl_module, predictions, batch_indices - ): + self, + trainer: lightning.pytorch.Trainer, + pl_module: lightning.pytorch.LightningModule, + predictions: typing.Sequence[typing.Any], + batch_indices: typing.Sequence[typing.Any] | None, + ) -> None: for dataloader_idx, dataloader_results in enumerate(predictions): dataloader_name = list( trainer.datamodule.predict_dataloader().keys() diff --git a/src/ptbench/engine/device.py b/src/ptbench/engine/device.py new file mode 100644 index 0000000000000000000000000000000000000000..253bba0d9da3bedca6010b2ad87937b9da4f08e0 --- /dev/null +++ b/src/ptbench/engine/device.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import logging +import os + +import torch +import torch.backends + +logger = logging.getLogger(__name__) + + +def _split_int_list(s: str) -> list[int]: + """Splits a list of integers encoded in a string (e.g. "1,2,3") into a + Python list of integers (e.g. ``[1, 2, 3]``).""" + return [int(k.strip()) for k in s.split(",")] + + +class DeviceManager: + """This class is used to manage the Lightning Accelerator and Pytorch + Devices. + + It takes the user input, in the form of a string defined by + ``[\\S+][:\\d[,\\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``), and can + translate to the right incarnation of Pytorch devices or Lightning + Accelerators to interface with the various frameworks. + + Instances of this class also manage the environment variable + ``$CUDA_VISIBLE_DEVICES`` if necessary. + + + Parameters + ---------- + + name + The name of the device to use, in the form of a string defined by + ``[\\S+][:\\d[,\\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``). In + the specific case of ``cuda``, one can also specify a device to use + either by adding ``:N``, where N is the zero-indexed board number on + the computer, or by setting the environment variable + ``$CUDA_VISIBLE_DEVICES`` with the devices that are usable by the + current process. 
+ """ + + SUPPORTED = ("cpu", "cuda", "mps") + + def __init__(self, name: str): + parts = name.split(":", 1) + self.device_type = parts[0] + self.device_ids: list[int] = [] + if len(parts) > 1: + self.device_ids = _split_int_list(parts[1]) + + if self.device_type == "cuda": + visible_env = os.environ.get("CUDA_VISIBLE_DEVICES") + if visible_env: + visible = _split_int_list(visible_env) + if self.device_ids and visible != self.device_ids: + logger.warning( + f"${{CUDA_VISIBLE_DEVICES}}={visible} and name={name} " + f"- overriding environment with value set on `name`" + ) + else: + self.device_ids = visible + + # make sure that it is consistent with the environment + if self.device_ids: + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( + [str(k) for k in self.device_ids] + ) + + if self.device_type not in DeviceManager.SUPPORTED: + raise RuntimeError( + f"Unsupported device type `{self.device_type}`. " + f"Supported devices types are `{', '.join(DeviceManager.SUPPORTED)}`" + ) + + if self.device_ids and self.device_type in ("cpu", "mps"): + logger.warning( + f"Cannot pin device ids if using cpu or mps backend. " + f"Setting `name` to {name} is non-sensical. Ignoring..." + ) + + # check if the device_type that was set has support compiled in + if self.device_type == "cuda": + assert hasattr(torch, "cuda") and torch.cuda.is_available(), ( + f"User asked for device = `{name}`, but CUDA support is " + f"not compiled into pytorch!" + ) + + if self.device_type == "mps": + assert ( + hasattr(torch.backends, "mps") + and torch.backends.mps.is_available() # type:ignore + ), ( + f"User asked for device = `{name}`, but MPS support is " + f"not compiled into pytorch!" + ) + + def torch_device(self) -> torch.device: + """Returns a representation of the torch device to use by default. + + .. warning:: + + If a list of devices is set, then this method only returns the first + device. This may impact Nvidia GPU logging in the case multiple + GPU cards are used. + + + Returns + ------- + + device + The **first** torch device (if a list of ids is set). + """ + + if self.device_type in ("cpu", "mps"): + return torch.device(self.device_type) + elif self.device_type == "cuda": + if not self.device_ids: + return torch.device(self.device_type) + else: + return torch.device(self.device_type, self.device_ids[0]) + + # if you get to this point, this is an unexpected RuntimeError + raise RuntimeError( + f"Unexpected device type {self.device_type} lacks support" + ) + + def lightning_accelerator(self) -> tuple[str, int | list[int] | str | None]: + """Returns the lightning accelerator setup. 
+ + Returns + ------- + + accelerator + The lightning accelerator to use + + devices + The lightning devices to use + """ + + devices: int | list[int] | str = self.device_ids + if not devices: + devices = "auto" + elif self.device_type == "mps": + devices = 1 + + return self.device_type, devices diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index 2c8bdc55f161ce567ddd7cb6641ac728ce3b7dbd..10121af1091a0e3a186efd4c338f4af0ad4b9cf5 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -7,70 +7,73 @@ import logging import os import shutil -from lightning.pytorch import Trainer -from lightning.pytorch.callbacks import ModelCheckpoint -from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger -from lightning.pytorch.utilities.model_summary import ModelSummary +import lightning.pytorch +import lightning.pytorch.callbacks +import lightning.pytorch.loggers +import torch.nn -from ..utils.accelerator import AcceleratorProcessor from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants from .callbacks import LoggingCallback logger = logging.getLogger(__name__) -def check_gpu(device): - """Check the device type and the availability of GPU. - - Parameters - ---------- - - device : :py:class:`torch.device` - device to use - """ - if device == "cuda": - # asserts we do have a GPU - assert bool( - gpu_constants() - ), f"Device set to '{device}', but nvidia-smi is not installed" - - -def save_model_summary(output_folder, model): +def save_model_summary( + output_folder: str, model: torch.nn.Module +) -> tuple[lightning.pytorch.callbacks.ModelSummary, int]: """Save a little summary of the model in a txt file. Parameters ---------- - output_folder : str + output_folder output path - model : :py:class:`torch.nn.Module` + model Network (e.g. driu, hed, unet) Returns ------- - r : str + summary: The model summary in a text format. - n : int + total_parameters: The number of parameters of the model. """ summary_path = os.path.join(output_folder, "model_summary.txt") logger.info(f"Saving model summary at {summary_path}...") with open(summary_path, "w") as f: - summary = ModelSummary(model, max_depth=-1) + summary = lightning.pytorch.utilities.model_summary.ModelSummary( + model, max_depth=-1 + ) f.write(str(summary)) - return summary, ModelSummary(model).total_parameters + return ( + summary, + lightning.pytorch.utilities.model_summary.ModelSummary( + model + ).total_parameters, + ) -def static_information_to_csv(static_logfile_name, device, n): - """Save the static information in a csv file. +def static_information_to_csv( + static_logfile_name: str, + device_type: str, + model_size: int, +) -> None: + """Saves the static information in a CSV file. 
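Typical use of the ``DeviceManager`` defined above, assuming a CUDA-enabled pytorch build (otherwise the constructor assertion fires):

.. code-block:: python

    from ptbench.engine.device import DeviceManager

    manager = DeviceManager("cuda:0")  # also exports CUDA_VISIBLE_DEVICES=0

    device = manager.torch_device()  # torch.device("cuda", 0)
    accelerator, devices = manager.lightning_accelerator()  # ("cuda", [0])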
Parameters ---------- - static_logfile_name : str - The static file name which is a join between the output folder and "constant.csv" + static_logfile_name + The static file name which is a join between the output folder and + "constant.csv" + + device_type + The type of device we are using + + model_size + The size of the model we will be training """ if os.path.exists(static_logfile_name): backup = static_logfile_name + "~" @@ -78,80 +81,23 @@ def static_information_to_csv(static_logfile_name, device, n): os.unlink(backup) shutil.move(static_logfile_name, backup) with open(static_logfile_name, "w", newline="") as f: - logdata = cpu_constants() - if device == "cuda": - logdata += gpu_constants() - logdata += (("model_size", n),) - logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata]) + logdata: dict[str, int | float | str] = {} + logdata.update(cpu_constants()) + if device_type == "cuda": + results = gpu_constants() + if results is not None: + logdata.update(results) + logdata["model_size"] = model_size + logwriter = csv.DictWriter(f, fieldnames=logdata.keys()) logwriter.writeheader() - logwriter.writerow(dict(k for k in logdata)) - - -def check_exist_logfile(logfile_name, arguments): - """Check existance of logfile (trainlog.csv), If the logfile exist the and - the epochs number are still 0, The logfile will be replaced. - - Parameters - ---------- - - logfile_name : str - The logfile_name which is a join between the output_folder and trainlog.csv - - arguments : dict - start and end epochs - """ - if arguments["epoch"] == 0 and os.path.exists(logfile_name): - backup = logfile_name + "~" - if os.path.exists(backup): - os.unlink(backup) - shutil.move(logfile_name, backup) - - -def create_logfile_fields(valid_loader, extra_valid_loaders, device): - """Creation of the logfile fields that will appear in the logfile. - - Parameters - ---------- - - valid_loader : :py:class:`torch.utils.data.DataLoader` - To be used to validate the model and enable automatic checkpointing. - If set to ``None``, then do not validate it. - - extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader` - To be used to validate the model, however **does not affect** automatic - checkpointing. If set to ``None``, or empty, then does not log anything - else. Otherwise, an extra column with the loss of every dataset in - this list is kept on the final training log. - - device : :py:class:`torch.device` - device to use - - Returns - ------- - - logfile_fields: tuple - The fields that will appear in trainlog.csv - """ - logfile_fields = ( - "epoch", - "total_time", - "eta", - "loss", - "learning_rate", - ) - if valid_loader is not None: - logfile_fields += ("validation_loss",) - if extra_valid_loaders: - logfile_fields += ("extra_validation_losses",) - logfile_fields += tuple(ResourceMonitor.monitored_keys(device == "cuda")) - return logfile_fields + logwriter.writerow(logdata) def run( model, datamodule, checkpoint_period, - accelerator, + device_manager, arguments, output_folder, monitoring_interval, @@ -187,8 +133,8 @@ def run( Save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do not save intermediary checkpoints. - accelerator : str - A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0) + device_manager : DeviceManager + A device, to be used for training. 
arguments : dict Start and end epochs: @@ -210,30 +156,30 @@ def run( max_epoch = arguments["max_epoch"] - accelerator_processor = AcceleratorProcessor(accelerator) - os.makedirs(output_folder, exist_ok=True) # Save model summary - r, n = save_model_summary(output_folder, model) + _, no_of_parameters = save_model_summary(output_folder, model) - csv_logger = CSVLogger(output_folder, "logs_csv") - tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard") + csv_logger = lightning.pytorch.loggers.CSVLogger(output_folder, "logs_csv") + tensorboard_logger = lightning.pytorch.loggers.TensorBoardLogger( + output_folder, "logs_tensorboard" + ) resource_monitor = ResourceMonitor( interval=monitoring_interval, - has_gpu=(accelerator_processor.accelerator == "gpu"), + has_gpu=device_manager.device_type == "cuda", main_pid=os.getpid(), logging_level=logging.ERROR, ) - checkpoint_callback = ModelCheckpoint( + checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint( output_folder, "model_lowest_valid_loss", save_last=True, monitor="validation_loss", mode="min", - save_on_train_epoch_end=False, + save_on_train_epoch_end=True, every_n_epochs=checkpoint_period, ) @@ -242,17 +188,15 @@ def run( # write static information to a CSV file static_logfile_name = os.path.join(output_folder, "constants.csv") static_information_to_csv( - static_logfile_name, accelerator_processor.to_torch(), n + static_logfile_name, + device_manager.device_type, + no_of_parameters, ) - if accelerator_processor.device is None: - devices = "auto" - else: - devices = accelerator_processor.device - with resource_monitor: - trainer = Trainer( - accelerator=accelerator_processor.accelerator, + accelerator, devices = device_manager.lightning_accelerator() + trainer = lightning.pytorch.Trainer( + accelerator=accelerator, devices=devices, max_epochs=max_epoch, accumulate_grad_batches=batch_chunk_count, diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index ba9bf05f7428d759489bd744f8ec35c3b43bab02..a878a076037925879c072bffb87d23a3e1ce7b0d 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -2,55 +2,185 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import logging +import typing + import lightning.pytorch as pl import torch -import torch.nn as nn +import torch.nn +import torch.nn.functional as F +import torch.optim.optimizer +import torch.utils.data import torchvision.models as models +import torchvision.transforms + +from ..data.typing import DataLoader, TransformSequence -from .normalizer import TorchVisionNormalizer +logger = logging.getLogger(__name__) class Alexnet(pl.LightningModule): """Alexnet module. Note: only usable with a normalized dataset + + Parameters + ---------- + + train_loss + The loss to be used during the training. + + .. warning:: + + The loss should be set to always return batch averages (as opposed + to the batch sum), as our logging system expects it so. + + validation_loss + The loss to be used for validation (may be different from the training + loss). If extra-validation sets are provided, the same loss will be + used throughout. + + .. warning:: + + The loss should be set to always return batch averages (as opposed + to the batch sum), as our logging system expects it so. + + optimizer_type + The type of optimizer to use for training + + optimizer_arguments + Arguments to the optimizer after ``params``. 
+ + augmentation_transforms + An optional sequence of torch modules containing transforms to be + applied on the input **before** it is fed into the network. + + pretrained + If set to True, loads pretrained model weights during initialization, else trains a new model. """ def __init__( self, - criterion, - criterion_valid, - optimizer, - optimizer_configs, - pretrained=False, + train_loss: torch.nn.Module, + validation_loss: torch.nn.Module | None, + optimizer_type: type[torch.optim.Optimizer], + optimizer_arguments: dict[str, typing.Any], + augmentation_transforms: TransformSequence = [], + pretrained: bool = False, ): super().__init__() - self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) + self.name = "alexnet" - self.name = "AlexNet" + self._train_loss = train_loss + self._validation_loss = ( + validation_loss if validation_loss is not None else train_loss + ) + self._optimizer_type = optimizer_type + self._optimizer_arguments = optimizer_arguments - # Load pretrained model - weights = ( - None if pretrained is False else models.AlexNet_Weights.DEFAULT + self._augmentation_transforms = torchvision.transforms.Compose( + augmentation_transforms ) - self.model_ft = models.alexnet(weights=weights) - self.normalizer = TorchVisionNormalizer(nb_channels=1) + self.pretrained = pretrained + + # Load pretrained model + if not pretrained: + weights = None + else: + logger.info(f"Loading pretrained {self.name} model weights") + weights = models.AlexNet_Weights.DEFAULT + + self.model_ft = models.alexnet(weights=weights) # Adapt output features - self.model_ft.classifier[4] = nn.Linear(4096, 512) - self.model_ft.classifier[6] = nn.Linear(512, 1) + self.model_ft.classifier[4] = torch.nn.Linear(4096, 512) + self.model_ft.classifier[6] = torch.nn.Linear(512, 1) def forward(self, x): - x = self.normalizer(x) + x = self.normalizer(x) # type: ignore + x = self.model_ft(x) return x - def training_step(self, batch, batch_idx): - images = batch[1] - labels = batch[2] + def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: + """Initializes the normalizer for the current model. + + This function is NOOP if ``pretrained = True`` (normalizer set to + imagenet weights, during contruction). + + Parameters + ---------- + + dataloader: :py:class:`torch.utils.data.DataLoader` + A torch Dataloader from which to compute the mean and std. + Will not be used if the model is pretrained. + """ + if self.pretrained: + from .normalizer import make_imagenet_normalizer + + logger.warning( + f"ImageNet pre-trained {self.name} model - NOT " + f"computing z-norm factors from train dataloader. " + f"Using preset factors from torchvision." + ) + self.normalizer = make_imagenet_normalizer() + else: + from .normalizer import make_z_normalizer + + logger.info( + f"Uninitialised {self.name} model - " + f"computing z-norm factors from train dataloader." + ) + self.normalizer = make_z_normalizer(dataloader) + + def balance_losses_by_class( + self, train_dataloader: DataLoader, valid_dataloader: DataLoader + ): + """Reweights loss weights if possible. + + Parameters + ---------- + + train_dataloader + The data loader to use for training + + valid_dataloader + The data loader to use for validation + + + Raises + ------ + + RuntimeError + If train or validation losses are not of type + :py:class:`torch.nn.BCEWithLogitsLoss`. 
+ """ + from .loss_weights import get_label_weights + + if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss): + logger.info("Reweighting BCEWithLogitsLoss training loss.") + weights = get_label_weights(train_dataloader) + self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights) + else: + raise RuntimeError( + "Training loss is not BCEWithLogitsLoss - dunno how to balance" + ) + + if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss): + logger.info("Reweighting BCEWithLogitsLoss validation loss.") + weights = get_label_weights(valid_dataloader) + self._validation_loss = torch.nn.BCEWithLogitsLoss(weights) + else: + raise RuntimeError( + "Validation loss is not BCEWithLogitsLoss - dunno how to balance" + ) + + def training_step(self, batch, _): + images = batch[0] + labels = batch[1]["label"] # Increase label dimension if too low # Allows single and multiclass usage @@ -58,17 +188,18 @@ class Alexnet(pl.LightningModule): labels = torch.reshape(labels, (labels.shape[0], 1)) # Forward pass on the network - outputs = self(images) - - # Manually move criterion to selected device, since not part of the model. - self.hparams.criterion = self.hparams.criterion.to(self.device) - training_loss = self.hparams.criterion(outputs, labels.float()) + augmented_images = [ + self._augmentation_transforms(img).to(self.device) for img in images + ] + # Combine list of augmented images back into a tensor + augmented_images = torch.cat(augmented_images, 0).view(images.shape) + outputs = self(augmented_images) - return {"loss": training_loss} + return self._train_loss(outputs, labels.float()) def validation_step(self, batch, batch_idx, dataloader_idx=0): - images = batch[1] - labels = batch[2] + images = batch[0] + labels = batch[1]["label"] # Increase label dimension if too low # Allows single and multiclass usage @@ -78,34 +209,23 @@ class Alexnet(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - # Manually move criterion to selected device, since not part of the model. 
-        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
-            self.device
-        )
-        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-        else:
-            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
+        return self._validation_loss(outputs, labels.float())

     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch[0]
-        images = batch[1]
+        images = batch[0]
+        labels = batch[1]["label"]
+        names = batch[1]["names"]

         outputs = self(images)
         probabilities = torch.sigmoid(outputs)

-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(outputs, list):
-            outputs = outputs[-1]
-
-        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+        return (
+            names[0],
+            torch.flatten(probabilities),
+            torch.flatten(labels),
+        )

     def configure_optimizers(self):
-        optimizer = getattr(torch.optim, self.hparams.optimizer)(
-            self.parameters(), **self.hparams.optimizer_configs
+        return self._optimizer_type(
+            self.parameters(), **self._optimizer_arguments
         )
-
-        return optimizer
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index a7cf9d567946899c874efce1664d0e7f65d5ace2..8eba3b53410da4874d8b59c48c73e13bec1cb703 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -2,55 +2,187 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later

+import logging
+import typing
+
 import lightning.pytorch as pl
 import torch
-import torch.nn as nn
+import torch.nn
+import torch.nn.functional as F
+import torch.optim.optimizer
+import torch.utils.data
 import torchvision.models as models
+import torchvision.transforms
+
+from ..data.typing import DataLoader, TransformSequence

-from .normalizer import TorchVisionNormalizer
+logger = logging.getLogger(__name__)


 class Densenet(pl.LightningModule):
-    """Densenet module.
+    """Densenet-121 module.
+
+    Parameters
+    ----------
+
+    train_loss
+        The loss to be used during the training.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects.
+
+    validation_loss
+        The loss to be used for validation (may be different from the training
+        loss). If extra-validation sets are provided, the same loss will be
+        used throughout.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects.
+
+    optimizer_type
+        The type of optimizer to use for training
+
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.

-    Note: only usable with a normalized dataset
+    pretrained
+        If set to True, loads pretrained model weights during initialization,
+        else trains a new model.
""" def __init__( self, - criterion, - criterion_valid, - optimizer, - optimizer_configs, - pretrained=False, - nb_channels=3, + train_loss: torch.nn.Module, + validation_loss: torch.nn.Module | None, + optimizer_type: type[torch.optim.Optimizer], + optimizer_arguments: dict[str, typing.Any], + augmentation_transforms: TransformSequence = [], + pretrained: bool= False, ): super().__init__() - self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) + self.name = "densenet-121" - self.name = "Densenet" + self._train_loss = train_loss + self._validation_loss = ( + validation_loss if validation_loss is not None else train_loss + ) + self._optimizer_type = optimizer_type + self._optimizer_arguments = optimizer_arguments + + self._augmentation_transforms = torchvision.transforms.Compose( + augmentation_transforms + ) - self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels) + self.pretrained = pretrained # Load pretrained model - weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT + if not pretrained: + weights = None + else: + logger.info(f"Loading pretrained {self.name} model weights") + weights = models.DenseNet121_Weights.DEFAULT + self.model_ft = models.densenet121(weights=weights) # Adapt output features - self.model_ft.classifier = nn.Sequential( - nn.Linear(1024, 256), nn.Linear(256, 1) + self.model_ft.classifier = torch.nn.Sequential( + torch.nn.Linear(1024, 256), torch.nn.Linear(256, 1) ) def forward(self, x): - x = self.normalizer(x) + + x = self.normalizer(x) # type: ignore + x = self.model_ft(x) return x - def training_step(self, batch, batch_idx): - images = batch[1] - labels = batch[2] + def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: + """Initializes the normalizer for the current model. + + This function is NOOP if ``pretrained = True`` (normalizer set to + imagenet weights, during contruction). + + Parameters + ---------- + + dataloader: :py:class:`torch.utils.data.DataLoader` + A torch Dataloader from which to compute the mean and std. + Will not be used if the model is pretrained. + """ + if self.pretrained: + from .normalizer import make_imagenet_normalizer + + logger.warning( + f"ImageNet pre-trained {self.name} model - NOT " + f"computing z-norm factors from train dataloader. " + f"Using preset factors from torchvision." + ) + self.normalizer = make_imagenet_normalizer() + else: + from .normalizer import make_z_normalizer + + logger.info( + f"Uninitialised {self.name} model - " + f"computing z-norm factors from train dataloader." + ) + self.normalizer = make_z_normalizer(dataloader) + + def balance_losses_by_class( + self, + train_dataloader: DataLoader, + valid_dataloader: dict[str, DataLoader], + ): + """Reweights loss weights if possible. + + Parameters + ---------- + + train_dataloader + The data loader to use for training + + valid_dataloader + The data loaders to use for each of the validation sets + + + Raises + ------ + + RuntimeError + If train or validation losses are not of type + :py:class:`torch.nn.BCEWithLogitsLoss`. 
+ """ + from .loss_weights import get_label_weights + + if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss): + logger.info("Reweighting BCEWithLogitsLoss training loss.") + weights = get_label_weights(train_dataloader) + self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights) + else: + raise RuntimeError( + "Training loss is not BCEWithLogitsLoss - dunno how to balance" + ) + + if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss): + logger.info("Reweighting BCEWithLogitsLoss validation loss.") + weights = get_label_weights(valid_dataloader) + self._validation_loss = torch.nn.BCEWithLogitsLoss(weights) + else: + raise RuntimeError( + "Validation loss is not BCEWithLogitsLoss - dunno how to balance" + ) + + def training_step(self, batch, _): + images = batch[0] + labels = batch[1]["label"] # Increase label dimension if too low # Allows single and multiclass usage @@ -58,17 +190,18 @@ class Densenet(pl.LightningModule): labels = torch.reshape(labels, (labels.shape[0], 1)) # Forward pass on the network - outputs = self(images) - - # Manually move criterion to selected device, since not part of the model. - self.hparams.criterion = self.hparams.criterion.to(self.device) - training_loss = self.hparams.criterion(outputs, labels.float()) + augmented_images = [ + self._augmentation_transforms(img).to(self.device) for img in images + ] + # Combine list of augmented images back into a tensor + augmented_images = torch.cat(augmented_images, 0).view(images.shape) + outputs = self(augmented_images) - return {"loss": training_loss} + return self._train_loss(outputs, labels.float()) def validation_step(self, batch, batch_idx, dataloader_idx=0): - images = batch[1] - labels = batch[2] + images = batch[0] + labels = batch[1]["label"] # Increase label dimension if too low # Allows single and multiclass usage @@ -78,35 +211,23 @@ class Densenet(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - # Manually move criterion to selected device, since not part of the model. 
-        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
-            self.device
-        )
-        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-        else:
-            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
+        return self._validation_loss(outputs, labels.float())

     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch[0]
-        images = batch[1]
+        images = batch[0]
+        labels = batch[1]["label"]
+        names = batch[1]["names"]

         outputs = self(images)
         probabilities = torch.sigmoid(outputs)

-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(outputs, list):
-            outputs = outputs[-1]
-
-        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+        return (
+            names[0],
+            torch.flatten(probabilities),
+            torch.flatten(labels),
+        )

     def configure_optimizers(self):
-        # Dynamically instantiates the optimizer given the configs
-        optimizer = getattr(torch.optim, self.hparams.optimizer)(
-            self.parameters(), **self.hparams.optimizer_configs
+        return self._optimizer_type(
+            self.parameters(), **self._optimizer_arguments
         )
-
-        return optimizer
diff --git a/src/ptbench/models/loss_weights.py b/src/ptbench/models/loss_weights.py
new file mode 100644
index 0000000000000000000000000000000000000000..6889b2539fb1071e228c4487540ad9272aad808c
--- /dev/null
+++ b/src/ptbench/models/loss_weights.py
@@ -0,0 +1,68 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import logging
+
+import torch
+import torch.utils.data
+
+logger = logging.getLogger(__name__)
+
+
+def get_label_weights(
+    dataloader: torch.utils.data.DataLoader,
+) -> torch.Tensor:
+    """Computes the weights of each class of a DataLoader.
+
+    This function inputs a pytorch DataLoader and computes the ratio between
+    the number of negative and positive samples (scalar).  The weight can be
+    used to adjust the minimisation criterion in cases where there is a large
+    data imbalance.
+
+    It returns a vector with weights (inverse counts) for each label.
+
+
+    Parameters
+    ----------
+
+    dataloader
+        A DataLoader from which to compute the positive weights.  Entries must
+        be a dictionary which must contain a ``label`` key.
+
+
+    Returns
+    -------
+
+    positive_weights
+        The positive weight of each class in the dataset given as input
+    """
+
+    targets = torch.tensor(
+        [sample for batch in dataloader for sample in batch[1]["label"]]
+    )
+
+    # Binary labels
+    if len(list(targets.shape)) == 1:
+        class_sample_count = [
+            float((targets == t).sum().item())
+            for t in torch.unique(targets, sorted=True)
+        ]
+
+        # Divide negatives by positives
+        positive_weights = torch.tensor(
+            [class_sample_count[0] / class_sample_count[1]]
+        ).reshape(-1)
+
+    # Multiclass labels
+    else:
+        class_sample_count = torch.sum(targets, dim=0)
+        negative_class_sample_count = (
+            torch.full((targets.size()[1],), float(targets.size()[0]))
+            - class_sample_count
+        )
+
+        positive_weights = negative_class_sample_count / (
+            class_sample_count + negative_class_sample_count
+        )
+
+    return positive_weights
diff --git a/src/ptbench/models/normalizer.py b/src/ptbench/models/normalizer.py
index aa2142a6bb01de4dfc011cb8b49c1521b82a5e29..ce68f4b558b2812d17a8f54954e197a086233a52 100644
--- a/src/ptbench/models/normalizer.py
+++ b/src/ptbench/models/normalizer.py
@@ -2,37 +2,74 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later

-"""A network model that prefixes a z-normalization step to any other module."""
+"""A network model that prefixes a subtract/divide step to any other module."""

 import torch
 import torch.nn
+import torch.utils.data
+import torchvision.transforms
+import tqdm


-class TorchVisionNormalizer(torch.nn.Module):
-    """A simple normalizer that applies the standard torchvision normalization.
+def make_z_normalizer(
+    dataloader: torch.utils.data.DataLoader,
+) -> torchvision.transforms.Normalize:
+    """Computes mean and standard deviation from a dataloader.
+
+    This function inputs a dataloader and computes the mean and standard
+    deviation per image channel.  It works for both monochromatic and color
+    inputs with 2, 3 or more color planes.

-    This module does not learn.

     Parameters
     ----------

-    nb_channels : :py:class:`int`, Optional
-        Number of images channels fed to the model
+    dataloader:
+        A torch Dataloader from which to compute the mean and std
+
+
+    Returns
+    -------
+    An initialized normalizer
+    """
+
+    # Peek the number of channels of batches in the data loader
+    batch = next(iter(dataloader))
+    channels = batch[0].shape[1]
+
+    # Initialises accumulators
+    mean = torch.zeros(channels, dtype=batch[0].dtype)
+    var = torch.zeros(channels, dtype=batch[0].dtype)
+    num_images = 0
+
+    # Evaluates mean and standard deviation
+    for batch in tqdm.tqdm(dataloader, unit="batch"):
+        data = batch[0]
+        data = data.view(data.size(0), data.size(1), -1)
+
+        num_images += data.size(0)
+        mean += data.mean(2).sum(0)
+        var += data.var(2).sum(0)
+
+    mean /= num_images
+    var /= num_images
+    std = torch.sqrt(var)
+
+    return torchvision.transforms.Normalize(mean, std)
+
+
+def make_imagenet_normalizer() -> torchvision.transforms.Normalize:
+    """Returns the stock ImageNet normalisation weights from torchvision.
+
+    The weights are wrapped in a torch module.  This normalizer only works for
+    **RGB (color) images**.
+
+
+    Returns
+    -------
+    An initialized normalizer
     """

-    def __init__(self, nb_channels=3):
-        super().__init__()
-        mean = torch.zeros(nb_channels)[None, :, None, None]
-        std = torch.ones(nb_channels)[None, :, None, None]
-        self.register_buffer("mean", mean)
-        self.register_buffer("std", std)
-        self.name = "torchvision-normalizer"
-
-    def set_mean_std(self, mean, std):
-        mean = torch.as_tensor(mean)[None, :, None, None]
-        std = torch.as_tensor(std)[None, :, None, None]
-        self.register_buffer("mean", mean)
-        self.register_buffer("std", std)
-
-    def forward(self, inputs):
-        return inputs.sub(self.mean).div(self.std)
+    return torchvision.transforms.Normalize(
+        (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
+    )
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 9da9702f31436df441cdb56d9415cc06e6829623..20bbb0dd9d06a122c81187762b8bbca0d04470fd 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -2,85 +2,143 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later

+import logging
+import typing
+
 import lightning.pytorch as pl
 import torch
-import torch.nn as nn
+import torch.nn
 import torch.nn.functional as F
+import torch.optim.optimizer
+import torch.utils.data
+import torchvision.transforms
+
+from ..data.typing import DataLoader, TransformSequence
+
+logger = logging.getLogger(__name__)
+
+
+class Pasa(pl.LightningModule):
+    """Implementation of CNN by Pasa.
+
+    Simple CNN for classification based on paper by [PASA-2019]_.
+
+
+    Parameters
+    ----------
+
+    train_loss
+        The loss to be used during the training.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects.
+
+    validation_loss
+        The loss to be used for validation (may be different from the training
+        loss). If extra-validation sets are provided, the same loss will be
+        used throughout.

-from .normalizer import TorchVisionNormalizer
+        .. warning::

+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects.

-class PASA(pl.LightningModule):
-    """PASA module.
+    optimizer_type
+        The type of optimizer to use for training

-    Based on paper by [PASA-2019]_.
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
""" def __init__( self, - criterion, - criterion_valid, - optimizer, - optimizer_configs, + train_loss: torch.nn.Module, + validation_loss: torch.nn.Module | None, + optimizer_type: type[torch.optim.Optimizer], + optimizer_arguments: dict[str, typing.Any], + augmentation_transforms: TransformSequence = [], ): super().__init__() - self.save_hyperparameters() - self.name = "pasa" - self.normalizer = TorchVisionNormalizer(nb_channels=1) + self._train_loss = train_loss + self._validation_loss = ( + validation_loss if validation_loss is not None else train_loss + ) + self._optimizer_type = optimizer_type + self._optimizer_arguments = optimizer_arguments + + self._augmentation_transforms = torchvision.transforms.Compose( + augmentation_transforms + ) # First convolution block - self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1)) - self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1)) - self.fc3 = nn.Conv2d(1, 16, (1, 1), (4, 4)) + self.fc1 = torch.nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1)) + self.fc2 = torch.nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1)) + self.fc3 = torch.nn.Conv2d(1, 16, (1, 1), (4, 4)) - self.batchNorm2d_4 = nn.BatchNorm2d(4) - self.batchNorm2d_16 = nn.BatchNorm2d(16) - self.batchNorm2d_16_2 = nn.BatchNorm2d(16) + self.batchNorm2d_4 = torch.nn.BatchNorm2d(4) + self.batchNorm2d_16 = torch.nn.BatchNorm2d(16) + self.batchNorm2d_16_2 = torch.nn.BatchNorm2d(16) # Second convolution block - self.fc4 = nn.Conv2d(16, 24, (3, 3), (1, 1), (1, 1)) - self.fc5 = nn.Conv2d(24, 32, (3, 3), (1, 1), (1, 1)) - self.fc6 = nn.Conv2d(16, 32, (1, 1), (1, 1)) # Original stride (2, 2) + self.fc4 = torch.nn.Conv2d(16, 24, (3, 3), (1, 1), (1, 1)) + self.fc5 = torch.nn.Conv2d(24, 32, (3, 3), (1, 1), (1, 1)) + self.fc6 = torch.nn.Conv2d( + 16, 32, (1, 1), (1, 1) + ) # Original stride (2, 2) - self.batchNorm2d_24 = nn.BatchNorm2d(24) - self.batchNorm2d_32 = nn.BatchNorm2d(32) - self.batchNorm2d_32_2 = nn.BatchNorm2d(32) + self.batchNorm2d_24 = torch.nn.BatchNorm2d(24) + self.batchNorm2d_32 = torch.nn.BatchNorm2d(32) + self.batchNorm2d_32_2 = torch.nn.BatchNorm2d(32) # Third convolution block - self.fc7 = nn.Conv2d(32, 40, (3, 3), (1, 1), (1, 1)) - self.fc8 = nn.Conv2d(40, 48, (3, 3), (1, 1), (1, 1)) - self.fc9 = nn.Conv2d(32, 48, (1, 1), (1, 1)) # Original stride (2, 2) + self.fc7 = torch.nn.Conv2d(32, 40, (3, 3), (1, 1), (1, 1)) + self.fc8 = torch.nn.Conv2d(40, 48, (3, 3), (1, 1), (1, 1)) + self.fc9 = torch.nn.Conv2d( + 32, 48, (1, 1), (1, 1) + ) # Original stride (2, 2) - self.batchNorm2d_40 = nn.BatchNorm2d(40) - self.batchNorm2d_48 = nn.BatchNorm2d(48) - self.batchNorm2d_48_2 = nn.BatchNorm2d(48) + self.batchNorm2d_40 = torch.nn.BatchNorm2d(40) + self.batchNorm2d_48 = torch.nn.BatchNorm2d(48) + self.batchNorm2d_48_2 = torch.nn.BatchNorm2d(48) # Fourth convolution block - self.fc10 = nn.Conv2d(48, 56, (3, 3), (1, 1), (1, 1)) - self.fc11 = nn.Conv2d(56, 64, (3, 3), (1, 1), (1, 1)) - self.fc12 = nn.Conv2d(48, 64, (1, 1), (1, 1)) # Original stride (2, 2) + self.fc10 = torch.nn.Conv2d(48, 56, (3, 3), (1, 1), (1, 1)) + self.fc11 = torch.nn.Conv2d(56, 64, (3, 3), (1, 1), (1, 1)) + self.fc12 = torch.nn.Conv2d( + 48, 64, (1, 1), (1, 1) + ) # Original stride (2, 2) - self.batchNorm2d_56 = nn.BatchNorm2d(56) - self.batchNorm2d_64 = nn.BatchNorm2d(64) - self.batchNorm2d_64_2 = nn.BatchNorm2d(64) + self.batchNorm2d_56 = torch.nn.BatchNorm2d(56) + self.batchNorm2d_64 = torch.nn.BatchNorm2d(64) + self.batchNorm2d_64_2 = torch.nn.BatchNorm2d(64) # Fifth convolution block - self.fc13 = nn.Conv2d(64, 72, (3, 3), (1, 1), (1, 1)) - 
self.fc14 = nn.Conv2d(72, 80, (3, 3), (1, 1), (1, 1))
-        self.fc15 = nn.Conv2d(64, 80, (1, 1), (1, 1))  # Original stride (2, 2)
+        self.fc13 = torch.nn.Conv2d(64, 72, (3, 3), (1, 1), (1, 1))
+        self.fc14 = torch.nn.Conv2d(72, 80, (3, 3), (1, 1), (1, 1))
+        self.fc15 = torch.nn.Conv2d(
+            64, 80, (1, 1), (1, 1)
+        )  # Original stride (2, 2)

-        self.batchNorm2d_72 = nn.BatchNorm2d(72)
-        self.batchNorm2d_80 = nn.BatchNorm2d(80)
-        self.batchNorm2d_80_2 = nn.BatchNorm2d(80)
+        self.batchNorm2d_72 = torch.nn.BatchNorm2d(72)
+        self.batchNorm2d_80 = torch.nn.BatchNorm2d(80)
+        self.batchNorm2d_80_2 = torch.nn.BatchNorm2d(80)

-        self.pool2d = nn.MaxPool2d((3, 3), (2, 2))  # Pool after conv. block
-        self.dense = nn.Linear(80, 1)  # Fully connected layer
+        self.pool2d = torch.nn.MaxPool2d(
+            (3, 3), (2, 2)
+        )  # Pool after conv. block
+        self.dense = torch.nn.Linear(80, 1)  # Fully connected layer

     def forward(self, x):
-        x = self.normalizer(x)
+        x = self.normalizer(x)  # type: ignore

         # First convolution block
         _x = x
@@ -127,9 +185,70 @@ class PASA(pl.LightningModule):

         return x

-    def training_step(self, batch, batch_idx):
-        images = batch["data"]
-        labels = batch["label"]
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initializes the input normalizer for the current model.
+
+        Parameters
+        ----------
+
+        dataloader
+            A torch Dataloader from which to compute the mean and std
+        """
+        from .normalizer import make_z_normalizer
+
+        logger.info(
+            f"Uninitialised {self.name} model - "
+            f"computing z-norm factors from train dataloader."
+        )
+        self.normalizer = make_z_normalizer(dataloader)
+
+    def balance_losses_by_class(
+        self,
+        train_dataloader: DataLoader,
+        valid_dataloader: dict[str, DataLoader],
+    ):
+        """Reweights loss weights if possible.
+
+        Parameters
+        ----------
+
+        train_dataloader
+            The data loader to use for training
+
+        valid_dataloader
+            The data loaders to use for each of the validation sets
+
+
+        Raises
+        ------
+
+        RuntimeError
+            If train or validation losses are not of type
+            :py:class:`torch.nn.BCEWithLogitsLoss`.
+        """
+        from .loss_weights import get_label_weights
+
+        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training loss.")
+            weights = get_label_weights(train_dataloader)
+            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+        else:
+            raise RuntimeError(
+                "Training loss is not BCEWithLogitsLoss - do not know how to balance"
+            )
+
+        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
+            weights = get_label_weights(valid_dataloader)
+            self._validation_loss = torch.nn.BCEWithLogitsLoss(
+                pos_weight=weights
+            )
+        else:
+            raise RuntimeError(
+                "Validation loss is not BCEWithLogitsLoss - do not know how to balance"
+            )
+
+    def training_step(self, batch, _):
+        images = batch[0]
+        labels = batch[1]["label"]

         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -137,17 +256,18 @@ class PASA(pl.LightningModule):
             labels = torch.reshape(labels, (labels.shape[0], 1))

         # Forward pass on the network
-        outputs = self(images)
+        augmented_images = [
+            self._augmentation_transforms(img).to(self.device) for img in images
+        ]
+        # Combine list of augmented images back into a tensor
+        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
+        outputs = self(augmented_images)

-        # Manually move criterion to selected device, since not part of the model.
- self.hparams.criterion = self.hparams.criterion.to(self.device) - training_loss = self.hparams.criterion(outputs, labels.double()) - - return {"loss": training_loss} + return self._train_loss(outputs, labels.float()) def validation_step(self, batch, batch_idx, dataloader_idx=0): - images = batch["data"] - labels = batch["label"] + images = batch[0] + labels = batch[1]["label"] # Increase label dimension if too low # Allows single and multiclass usage @@ -157,63 +277,23 @@ class PASA(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - # Manually move criterion to selected device, since not part of the model. - self.hparams.criterion_valid = self.hparams.criterion_valid.to( - self.device - ) - validation_loss = self.hparams.criterion_valid(outputs, labels.double()) - - if dataloader_idx == 0: - return {"validation_loss": validation_loss} - else: - return {f"extra_validation_loss_{dataloader_idx}": validation_loss} + return self._validation_loss(outputs, labels.float()) def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): - names = batch["name"] - images = batch["data"] - labels = batch["label"] + images = batch[0] + labels = batch[1]["label"] + names = batch[1]["names"] outputs = self(images) probabilities = torch.sigmoid(outputs) - # necessary check for HED architecture that uses several outputs - # for loss calculation instead of just the last concatfuse block - if isinstance(outputs, list): - outputs = outputs[-1] - - results = ( + return ( names[0], torch.flatten(probabilities), torch.flatten(labels), ) - return results - # { - # f"dataloader_{dataloader_idx}_predictions": ( - # names[0], - # torch.flatten(probabilities), - # torch.flatten(labels), - # ) - # } - - # def on_predict_epoch_end(self): - - # retval = defaultdict(list) - - # for dataloader_name, predictions in self.predictions_cache.items(): - # for prediction in predictions: - # retval[dataloader_name]["name"].append(prediction[0]) - # retval[dataloader_name]["prediction"].append(prediction[1]) - # retval[dataloader_name]["label"].append(prediction[2]) - - # Need to cache predictions in the predict step, then reorder by key - # Clear prediction dict - # raise NotImplementedError - def configure_optimizers(self): - # Dynamically instantiates the optimizer given the configs - optimizer = getattr(torch.optim, self.hparams.optimizer)( - self.parameters(), **self.hparams.optimizer_configs + return self._optimizer_type( + self.parameters(), **self._optimizer_arguments ) - - return optimizer diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index 4d2a226b5b0b479b3f84c748d24aebd43d8d6dea..664b8b1ad1ae38625a22af4a1092b15da20b2727 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -6,9 +6,6 @@ import click from clapper.click import ConfigCommand, ResourceOption, verbosity_option from clapper.logging import setup -from lightning.pytorch import seed_everything - -from ..utils.checkpointer import get_checkpoint logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -19,7 +16,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") epilog="""Examples: \b - 1. Trains PASA model with Montgomery dataset, on a GPU (``cuda:0``): + 1. Trains Pasa's model with Montgomery dataset, on a GPU (``cuda:0``): .. 
code:: sh

@@ -39,47 +36,17 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--model",
     "-m",
-    help="A torch.nn.Module instance implementing the network to be trained",
+    help="A lightning module instance implementing the network to be trained",
     required=True,
     cls=ResourceOption,
 )
 @click.option(
     "--datamodule",
     "-d",
-    help="A dictionary mapping string keys to "
-    "torch.utils.data.dataset.Dataset instances implementing datasets "
-    "to be used for training and validating the model, possibly including all "
-    "pre-processing pipelines required or, optionally, a dictionary mapping "
-    "string keys to torch.utils.data.dataset.Dataset instances.  At least "
-    "one key named ``train`` must be available.  This dataset will be used for "
-    "training the network model.  The dataset description must include all "
-    "required pre-processing, including eventual data augmentation.  If a "
-    "dataset named ``__train__`` is available, it is used prioritarily for "
-    "training instead of ``train``.  If a dataset named ``__valid__`` is "
-    "available, it is used for model validation (and automatic "
-    "check-pointing) at each epoch.  If a dataset list named "
-    "``__extra_valid__`` is available, then it will be tracked during the "
-    "validation process and its loss output at the training log as well, "
-    "in the format of an array occupying a single column.  All other keys "
-    "are considered test datasets and are ignored during training",
-    required=True,
-    cls=ResourceOption,
-)
-@click.option(
-    "--criterion",
-    help="A loss function to compute the CNN error for every sample "
-    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
+    help="A lightning data module containing the training and validation sets.",
     required=True,
     cls=ResourceOption,
 )
-@click.option(
-    "--criterion-valid",
-    help="A specific loss function for the validation set to compute the CNN"
-    "error for every sample respecting the PyTorch API for loss functions"
-    "(see torch.nn.modules.loss)",
-    required=False,
-    cls=ResourceOption,
-)
 @click.option(
     "--batch-size",
     "-b",
@@ -157,17 +124,18 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @click.option(
-    "--accelerator",
-    "-a",
-    help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)',
+    "--device",
+    "-x",
+    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
     show_default=True,
     required=True,
     default="cpu",
     cls=ResourceOption,
 )
 @click.option(
-    "--cache-samples",
-    help="If set to True, loads the sample into memory, otherwise loads them at runtime.",
+    "--cache-samples/--no-cache-samples",
+    help="If set to True, loads the samples into memory, "
+    "otherwise loads them at runtime.",
     required=True,
     show_default=True,
     default=False,
@@ -196,16 +164,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     default=-1,
     cls=ResourceOption,
 )
-@click.option(
-    "--normalization",
-    "-n",
-    help="Z-Normalization of input images: 'imagenet' for ImageNet parameters,"
-    " 'current' for parameters of the current trainset, "
-    "'none' for no normalization.",
-    required=False,
-    default="none",
-    cls=ResourceOption,
-)
 @click.option(
     "--monitoring-interval",
     "-I",
@@ -224,12 +182,25 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 )
 @click.option(
     "--resume-from",
-    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a model checkpoint.",
+    help="Which checkpoint to resume training from. If set, can be one of "
+    "`best`, `last`, or a path to a model checkpoint.",
    type=str,
    required=False,
    default=None,
    cls=ResourceOption,
 )
+@click.option(
+    "--balance-classes/--no-balance-classes",
+    "-B/-N",
+    help="""If set, then balances weights of the random sampler during
+    training, so that samples from all sample classes are picked
+    equitably.  It also sets the training (and validation) losses to account
+    for the populations of each class.""",
+    required=True,
+    show_default=True,
+    default=True,
+    cls=ResourceOption,
+)
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def train(
     model,
@@ -238,20 +209,18 @@ def train(
     batch_size,
     batch_chunk_count,
     drop_incomplete_batch,
-    criterion,
-    criterion_valid,
     datamodule,
     checkpoint_period,
-    accelerator,
+    device,
     cache_samples,
     seed,
     parallel,
-    normalization,
     monitoring_interval,
     resume_from,
+    balance_classes,
     **_,
-):
-    """Trains an CNN to perform tuberculosis detection.
+) -> None:
+    """Trains a CNN to perform image classification.

     Training is performed for a configurable number of epochs, and generates
     at least a final_model.pth.  It may also generate a number
@@ -263,32 +232,56 @@ def train(
     import torch.cuda
     import torch.nn

-    from ..data.dataset import normalize_data, reweight_BCEWithLogitsLoss
+    from lightning.pytorch import seed_everything
+
+    from ..engine.device import DeviceManager
     from ..engine.trainer import run
+    from ..utils.checkpointer import get_checkpoint
+    from .utils import save_sh_command

+    save_sh_command(output_folder)
     seed_everything(seed)

     checkpoint_file = get_checkpoint(output_folder, resume_from)

-    datamodule.update_module_properties(
-        batch_size=batch_size,
-        batch_chunk_count=batch_chunk_count,
-        drop_incomplete_batch=drop_incomplete_batch,
-        cache_samples=cache_samples,
-        parallel=parallel,
-    )
+    # reset datamodule with user configurable options
+    datamodule.set_chunk_size(batch_size, batch_chunk_count)
+    datamodule.drop_incomplete_batch = drop_incomplete_batch
+    datamodule.cache_samples = cache_samples
+    datamodule.parallel = parallel

     datamodule.prepare_data()
     datamodule.setup(stage="fit")

-    reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid)
-    normalize_data(normalization, model, datamodule)
+    # Sets the model normalizer with the unaugmented-train-subset.
+    # this call may be a NOOP, if the model was pre-trained and expects
+    # different weights for the normalisation layer.
+    if hasattr(model, "set_normalizer"):
+        model.set_normalizer(datamodule.unshuffled_train_dataloader())
+    else:
+        logger.warning(
+            f"Model {model.name} has no 'set_normalizer' method. Skipping."
+        )
+
+    # If asked, rebalances the loss criterion based on the relative proportion
+    # of class examples available in the training set.  Also affects the
+    # validation loss if a validation set is available on the data module.
+    if balance_classes:
+        logger.info("Applying datamodule train sampler balancing...")
+        datamodule.balance_sampler_by_class = True
+        # logger.info("Applying train/valid loss balancing...")
+        # model.balance_losses_by_class(datamodule)
+    else:
+        logger.info(
+            "Skipping sample class/dataset ownership balancing on user request"
+        )

     arguments = {}
     arguments["max_epoch"] = epochs
     arguments["epoch"] = 0

-    # We only load the checkpoint to get some information about its state. The actual loading of the model is done in trainer.fit()
+    # We only load the checkpoint to get some information about its state. The
+    # actual loading of the model is done in trainer.fit()
     if checkpoint_file is not None:
         checkpoint = torch.load(checkpoint_file)
         arguments["epoch"] = checkpoint["epoch"]
@@ -300,7 +293,7 @@ def train(
         model=model,
         datamodule=datamodule,
         checkpoint_period=checkpoint_period,
-        accelerator=accelerator,
+        device_manager=DeviceManager(device),
         arguments=arguments,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
diff --git a/src/ptbench/scripts/utils.py b/src/ptbench/scripts/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5553a6ed1f4d79d1eba0ea6923c9821ce918dcd1
--- /dev/null
+++ b/src/ptbench/scripts/utils.py
@@ -0,0 +1,61 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import importlib.metadata
+import logging
+import os
+import pathlib
+import sys
+import time
+
+logger = logging.getLogger(__name__)
+
+
+def save_sh_command(output_folder: str | pathlib.Path) -> None:
+    """Records command-line to reproduce this script.
+
+    This function can record the current command-line used to call the script
+    being run.  It creates an executable shell script setting up the current
+    working directory and activating a conda environment, if needed.  It
+    records further information on the date and time the script was run and
+    the version of the package.
+
+
+    Parameters
+    ----------
+
+    output_folder
+        Path leading to the directory where the commands to reproduce the
+        current run will be recorded.  This directory is created if it does
+        not exist.
+    """
+
+    if isinstance(output_folder, str):
+        output_folder = pathlib.Path(output_folder)
+
+    destfile = output_folder / "command.sh"
+
+    logger.info(f"Writing command-line for reproduction at '{destfile}'...")
+    os.makedirs(output_folder, exist_ok=True)
+
+    package = __name__.split(".", 1)[0]
+    version = importlib.metadata.version(package)
+
+    with destfile.open("w") as f:
+        f.write("#!/usr/bin/env sh\n")
+        f.write(f"# date: {time.asctime()}\n")
+        f.write(f"# version: {version} ({package})\n")
+        f.write(f"# platform: {sys.platform}\n")
+        f.write("\n")
+        args = []
+        for k in sys.argv:
+            if " " in k:
+                args.append(f'"{k}"')
+            else:
+                args.append(k)
+        if os.environ.get("CONDA_DEFAULT_ENV") is not None:
+            f.write(f"# conda activate {os.environ['CONDA_DEFAULT_ENV']}\n")
+        f.write(f"# cd {os.path.realpath(os.curdir)}\n")
+        f.write(" ".join(args) + "\n")
+    os.chmod(destfile, 0o755)
diff --git a/src/ptbench/utils/accelerator.py b/src/ptbench/utils/accelerator.py
deleted file mode 100644
index dcfa2f733e1d091c5bb9a4e5785ee47f8e49497c..0000000000000000000000000000000000000000
--- a/src/ptbench/utils/accelerator.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-import logging
-import os
-
-import torch
-
-logger = logging.getLogger(__name__)
-
-
-class AcceleratorProcessor:
-    """This class is used to convert the torch device naming convention to
-    lightning's device convention and vice versa.
-
-    It also sets the CUDA_VISIBLE_DEVICES if a gpu accelerator is used.
-    """
-
-    def __init__(self, name):
-        # Note: "auto" is a valid accelerator in lightning, but there doesn't seem to be a way to check which accelerator it will actually use so we don't take it into account for now.
- self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu"} - - self.lightning_to_torch = { - v: k for k, v in self.torch_to_lightning.items() - } - - self.valid_accelerators = set( - list(self.torch_to_lightning.keys()) - + list(self.lightning_to_torch.keys()) - ) - - self.accelerator, self.device = self._split_accelerator_name(name) - - if self.accelerator not in self.valid_accelerators: - raise ValueError(f"Unknown accelerator {self.accelerator}") - - # Keep lightning's convention by default - self.accelerator = self.to_lightning() - self.setup_accelerator() - - def setup_accelerator(self): - """If a gpu accelerator is chosen, checks the CUDA_VISIBLE_DEVICES - environment variable exists or sets its value if specified.""" - if self.accelerator == "gpu": - if not torch.cuda.is_available(): - raise RuntimeError( - f"CUDA is not currently available, but " - f"you set accelerator to '{self.accelerator}'" - ) - - if self.device is not None: - os.environ["CUDA_VISIBLE_DEVICES"] = str(self.device[0]) - else: - if os.environ.get("CUDA_VISIBLE_DEVICES") is None: - raise ValueError( - "Environment variable 'CUDA_VISIBLE_DEVICES' is not set." - "Please set 'CUDA_VISIBLE_DEVICES' of specify a device to use, e.g. cuda:0" - ) - else: - # No need to check the CUDA_VISIBLE_DEVICES environment variable if cpu - pass - - logger.info( - f"Accelerator set to {self.accelerator} and device to {self.device}" - ) - - def _split_accelerator_name(self, accelerator_name): - """Splits an accelerator string into accelerator and device components. - - Parameters - ---------- - - accelerator_name: str - The accelerator (or device in pytorch convention) string (e.g. cuda:0) - - Returns - ------- - - accelerator: str - The accelerator name - device: dict[int] - The selected devices - """ - - split_accelerator = accelerator_name.split(":") - accelerator = split_accelerator[0] - - if len(split_accelerator) > 1: - device = split_accelerator[1] - device = [int(device)] - else: - device = None - - return accelerator, device - - def to_torch(self): - """Converts the accelerator string to torch convention. - - Returns - ------- - - accelerator: str - The accelerator name in pytorch convention - """ - if self.accelerator in self.lightning_to_torch: - return self.lightning_to_torch[self.accelerator] - elif self.accelerator in self.torch_to_lightning: - return self.accelerator - else: - raise ValueError("Unknown accelerator.") - - def to_lightning(self): - """Converts the accelerator string to lightning convention. - - Returns - ------- - - accelerator: str - The accelerator name in lightning convention - """ - if self.accelerator in self.torch_to_lightning: - return self.torch_to_lightning[self.accelerator] - elif self.accelerator in self.lightning_to_torch: - return self.accelerator - else: - raise ValueError("Unknown accelerator.") diff --git a/src/ptbench/utils/image.py b/src/ptbench/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..363a8309f4581dfdda42124c0562d3bf840904af --- /dev/null +++ b/src/ptbench/utils/image.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import os + +from typing import Union + +import torch + +from PIL.Image import Image +from torchvision import transforms + + +def save_image(img: Union[torch.Tensor, Image], filepath: str) -> None: + """Saves a PIL image or a tensor as an image at the specified destination. 
+
+    Parameters
+    ----------
+
+    img:
+        A torch.Tensor or PIL.Image to save
+
+    filepath:
+        The file in which to save the image.  The format is inferred from the
+        file extension, or defaults to png if not specified.
+    """
+
+    if isinstance(img, torch.Tensor):
+        img = transforms.ToPILImage()(img)
+
+    _, ext = os.path.splitext(filepath)
+
+    if len(ext) == 0:
+        filepath = filepath + ".png"
+
+    img.save(filepath)
diff --git a/src/ptbench/utils/resources.py b/src/ptbench/utils/resources.py
index ebad7794a6975932926ab2287801bb3cb052cd30..f7c7f6b958093a770bdc9c448b79e04a0b3dc704 100644
--- a/src/ptbench/utils/resources.py
+++ b/src/ptbench/utils/resources.py
@@ -6,11 +6,13 @@
 import logging
 import multiprocessing
+import multiprocessing.synchronize
 import os
 import queue
 import shutil
 import subprocess
 import time
+import typing

 import numpy
 import psutil
@@ -25,7 +27,9 @@
 GB = float(2**30)
 """The number of bytes in a gigabyte."""


-def run_nvidia_smi(query, rename=None):
+def run_nvidia_smi(
+    query: typing.Sequence[str],
+) -> dict[str, str | float] | None:
     """Returns GPU information from query.

     For a comprehensive list of options and help, execute ``nvidia-smi
     --help-query-gpu`` on your host.

     Parameters
     ----------

-    query : list
+    query
         A list of query strings as defined by ``nvidia-smi --help-query-gpu``

-    rename : :py:class:`list`, Optional
-        A list of keys to yield in the return value for each entry above.  It
-        gives you the opportunity to rewrite some key names for convenience.
-        This list, if provided, must be of the same length as ``query``.
-

     Returns
     -------

-    data : :py:class:`tuple`, None
-        An ordered dictionary (organized as 2-tuples) containing the queried
-        parameters (``rename`` versions).  If ``nvidia-smi`` is not available,
-        returns ``None``.  Percentage information is left alone,
-        memory information is transformed to gigabytes (floating-point).
+    data
+        A dictionary containing the queried parameters.  If ``nvidia-smi`` is
+        not available, returns ``None``.  Percentage information is left
+        alone, memory information is transformed to gigabytes
+        (floating-point).
     """
-    if _nvidia_smi is not None:
-        if rename is None:
-            rename = query
-        else:
-            assert len(rename) == len(query)
-
-        # Get GPU information based on GPU ID.
-        values = subprocess.getoutput(
-            "%s --query-gpu=%s --format=csv,noheader --id=%s"
-            % (
-                _nvidia_smi,
-                ",".join(query),
-                os.environ.get("CUDA_VISIBLE_DEVICES"),
-            )
-        )
-        values = [k.strip() for k in values.split(",")]
-        t_values = []
-        for k in values:
-            if k.endswith("%"):
-                t_values.append(float(k[:-1].strip()))
-            elif k.endswith("MiB"):
-                t_values.append(float(k[:-3].strip()) / 1024)
-            else:
-                t_values.append(k)  # unchanged
-        return tuple(zip(rename, t_values))
-
-
-def gpu_constants():
+    if _nvidia_smi is None:
+        return None
+
+    # Gets GPU information, based on a GPU device if that is set.  Returns
+    # ordered results.
+    query_str = (
+        f"{_nvidia_smi} --query-gpu={','.join(query)} --format=csv,noheader"
+    )
+    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+    if visible_devices:
+        query_str += f" --id={visible_devices}"
+    values = subprocess.getoutput(query_str)
+
+    retval: dict[str, str | float] = {}
+    for i, k in enumerate([k.strip() for k in values.split(",")]):
+        retval[query[i]] = k
+        if k.endswith("%"):
+            retval[query[i]] = float(k[:-1].strip())
+        elif k.endswith("MiB"):
+            retval[query[i]] = float(k[:-3].strip()) / 1024
+    return retval
+
+
+def gpu_constants() -> dict[str, str | int | float] | None:
     """Returns GPU (static) information using nvidia-smi.

     See :py:func:`run_nvidia_smi` for operational details.

@@ -90,21 +85,25 @@

     data : :py:class:`tuple`, None
         If ``nvidia-smi`` is not available, returns ``None``, otherwise, we
-        return an ordered dictionary (organized as 2-tuples) containing the
-        following ``nvidia-smi`` query information:
+        return a dictionary containing the following ``nvidia-smi`` query
+        information, in this order:

         * ``gpu_name``, as ``gpu_name`` (:py:class:`str`)
         * ``driver_version``, as ``gpu_driver_version`` (:py:class:`str`)
         * ``memory.total``, as ``gpu_memory_total`` (transformed to gigabytes,
           :py:class:`float`)
     """
-    return run_nvidia_smi(
-        ("gpu_name", "driver_version", "memory.total"),
-        ("gpu_name", "gpu_driver_version", "gpu_memory_total_GB"),
-    )
+    retval = run_nvidia_smi(("gpu_name", "driver_version", "memory.total"))
+    if retval is None:
+        return retval
+
+    # else, just update with more generic names
+    retval["gpu_driver_version"] = retval.pop("driver_version")
+    retval["gpu_memory_total_GB"] = retval.pop("memory.total")
+    return retval


-def gpu_log():
+def gpu_log() -> dict[str, float] | None:
     """Returns GPU information about current non-static status using nvidia-
     smi.

     See :py:func:`run_nvidia_smi` for operational details.

@@ -113,10 +112,10 @@
     Returns
     -------

-    data : :py:class:`tuple`, None
+    data
         If ``nvidia-smi`` is not available, returns ``None``, otherwise, we
-        return an ordered dictionary (organized as 2-tuples) containing the
-        following ``nvidia-smi`` query information:
+        return a dictionary containing the following ``nvidia-smi`` query
+        information, in this order:

         * ``memory.used``, as ``gpu_memory_used`` (transformed to gigabytes,
           :py:class:`float`)
@@ -127,47 +126,41 @@
         * ``utilization.gpu``, as ``gpu_percent``,
           (:py:class:`float`, in percent)
     """
-    retval = run_nvidia_smi(
-        (
-            "memory.total",
-            "memory.used",
-            "memory.free",
-            "utilization.gpu",
-        ),
-        (
-            "gpu_memory_total_GB",
-            "gpu_memory_used_GB",
-            "gpu_memory_free_percent",
-            "gpu_usage_percent",
-        ),
-    )
-
-    # re-compose the output to generate expected values
-    return (
-        retval[1],  # gpu_memory_used
-        retval[2],  # gpu_memory_free
-        ("gpu_memory_percent", 100 * (retval[1][1] / retval[0][1])),
-        retval[3],  # gpu_percent
+    result = run_nvidia_smi(
+        ("memory.total", "memory.used", "memory.free", "utilization.gpu")
     )
+    if result is None:
+        return result

-def cpu_constants():
+    return {
+        "gpu_memory_used_GB": float(result["memory.used"]),
+        "gpu_memory_free_GB": float(result["memory.free"]),
+        "gpu_memory_percent": 100
+        * float(result["memory.used"])
+        / float(result["memory.total"]),
+        "gpu_percent": float(result["utilization.gpu"]),
+    }
+
+
+def cpu_constants() -> dict[str, int | float]:
     """Returns static CPU information about the current system.

     Returns
     -------

     data : tuple
         An ordered dictionary (organized as 2-tuples) containing these entries:

         0. ``cpu_memory_total`` (:py:class:`float`): total memory available, in
            gigabytes
         1. ``cpu_count`` (:py:class:`int`): number of logical CPUs available

     """
-    return (
-        ("cpu_memory_total_GB", psutil.virtual_memory().total / GB),
-        ("cpu_count", psutil.cpu_count(logical=True)),
-    )
+    return {
+        "cpu_memory_total_GB": psutil.virtual_memory().total / GB,
+        "cpu_count": psutil.cpu_count(logical=True),
+    }


 class CPULogger:
@@ -176,24 +169,24 @@ class CPULogger:
     Parameters
     ----------

-    pid : :py:class:`int`, Optional
+    pid
         Process identifier of the main process (parent process) to observe
     """

-    def __init__(self, pid=None):
+    def __init__(self, pid: int | None = None):
         this = psutil.Process(pid=pid)
         self.cluster = [this] + this.children(recursive=True)
         # touch cpu_percent() at least once for all processes in the cluster
         [k.cpu_percent(interval=None) for k in self.cluster]

-    def log(self):
-        """Returns current process cluster information.
+    def log(self) -> dict[str, int | float]:
+        """Returns current process cluster information.

         Returns
         -------

-        data : tuple
-            An ordered dictionary (organized as 2-tuples) containing these entries:
+        data
+            An ordered dictionary containing these entries:

             0. ``cpu_memory_used`` (:py:class:`float`): total memory used from
                the system, in gigabytes
@@ -244,14 +237,14 @@ class CPULogger:
                     # it is too late to update any intermediate list
                     # at this point, but ensures to update counts later on
                     gone.add(k)
-        return (
-            ("cpu_memory_used_GB", psutil.virtual_memory().used / GB),
-            ("cpu_rss_GB", sum([k.rss for k in memory_info]) / GB),
-            ("cpu_vms_GB", sum([k.vms for k in memory_info]) / GB),
-            ("cpu_percent", sum(cpu_percent)),
-            ("cpu_processes", len(self.cluster) - len(gone)),
-            ("cpu_open_files", sum(open_files)),
-        )
+        return {
+            "cpu_memory_used_GB": psutil.virtual_memory().used / GB,
+            "cpu_rss_GB": sum([k.rss for k in memory_info]) / GB,
+            "cpu_vms_GB": sum([k.vms for k in memory_info]) / GB,
+            "cpu_percent": sum(cpu_percent),
+            "cpu_processes": len(self.cluster) - len(gone),
+            "cpu_open_files": sum(open_files),
+        }


 class _InformationGatherer:
@@ -260,73 +253,85 @@ class _InformationGatherer:
     Parameters
     ----------

-    has_gpu : bool
+    has_gpu
         A flag indicating if we have a GPU installed on the platform or not

-    main_pid : int
+    main_pid
         The main process identifier to monitor

-    logger : logging.Logger
+    logger
         A logger to be used for logging messages
     """

-    def __init__(self, has_gpu, main_pid, logger):
+    def __init__(
+        self, has_gpu: bool, main_pid: int | None, logger: logging.Logger
+    ):
+        self.logger: logging.Logger = logger
         self.cpu_logger = CPULogger(main_pid)
-        self.keys = [k[0] for k in self.cpu_logger.log()]
-        self.cpu_keys_len = len(self.keys)
-        self.has_gpu = has_gpu
-        self.logger = logger
+        keys: list[str] = list(self.cpu_logger.log().keys())
+        self.has_gpu: bool = has_gpu
         if self.has_gpu:
-            self.keys += [k[0] for k in gpu_log()]
-        self.data = [[] for _ in self.keys]
+            example = gpu_log()
+            if example is not None:
+                keys += list(example.keys())
+        self.data: dict[str, list[int | float]] = {k: [] for k in keys}

-    def acc(self):
+    def acc(self) -> None:
         """Accumulates another measurement."""
-        for i, k in enumerate(self.cpu_logger.log()):
-            self.data[i].append(k[1])
+        for k, v in self.cpu_logger.log().items():
+            self.data[k].append(v)
         if self.has_gpu:
-            for i, k in enumerate(gpu_log()):
-                self.data[i + self.cpu_keys_len].append(k[1])
+            sample = gpu_log()
+            if sample is not None:
+                for k, v in sample.items():
+                    self.data[k].append(v)

-    def clear(self):
+    def clear(self) -> None:
         """Clears accumulated data."""
-        self.data = [[] for _ in self.keys]
+        for k in self.data.keys():
+            self.data[k] = []

-    def summary(self):
+    def summary(self) -> dict[str, list[int | float]]:
         """Returns the current data."""
-        if len(self.data[0]) == 0:
+        if len(next(iter(self.data.values()))) == 0:
             self.logger.error("CPU/GPU logger was not able to collect any data")
-        retval = []
-        for k, values in zip(self.keys, self.data):
-            retval.append((k, values))
-        return tuple(retval)
+        return self.data


 def _monitor_worker(
-    interval, has_gpu, main_pid, stop, summary_event, queue, logging_level
+    interval: int | float,
+    has_gpu: bool,
+    main_pid: int,
+    stop: multiprocessing.synchronize.Event,
+    summary_event: multiprocessing.synchronize.Event,
+    queue: queue.Queue,
+    logging_level: int,
 ):
     """A monitoring worker that measures resources and returns lists.

     Parameters
     ==========

-    interval : int, float
+    interval
         Number of seconds to wait between each measurement (maybe a floating
         point number as accepted by :py:func:`time.sleep`)

-    has_gpu : bool
+    has_gpu
         A flag indicating if we have a GPU installed on the platform or not

-    main_pid : int
+    main_pid
         The main process identifier to monitor

-    stop : :py:class:`multiprocessing.Event`
-        Indicates if we should continue running or stop
+    stop
        Event that indicates if we should continue running or stop

-    queue : :py:class:`queue.Queue`
+    summary_event
+        Event that indicates if we should produce a summary
+
+    queue
         A queue, to send monitoring information back to the spawner

-    logging_level: int
+    logging_level
         The logging level to use for logging from launched processes
     """
     logger = multiprocessing.log_to_stderr(level=logging_level)
@@ -343,9 +348,9 @@ def _monitor_worker(
             time.sleep(interval)

         except Exception:
-            logger.warning(
-                "Iterative CPU/GPU logging did not work properly " "this once",
-                exc_info=True,
+            logger.exception(
+                "Iterative CPU/GPU logging did not work properly."
+                " Exception follows.  Retrying..."
             )
             time.sleep(0.5)  # wait half a second, and try again!
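
Aside: the dictionary-based resource API above is straightforward to smoke-test by hand.  The following is a minimal sketch (not part of this changeset; it assumes the package is installed so that ``ptbench.utils.resources`` is importable, and the printed values are illustrative only):

.. code:: python

   from ptbench.utils.resources import CPULogger, cpu_constants

   # static facts about the current machine, now returned as a plain dict
   print(cpu_constants())  # e.g. {'cpu_memory_total_GB': 15.5, 'cpu_count': 8}

   # observe the current process tree (pid=None means "this process")
   tracker = CPULogger()
   sample = tracker.log()  # one measurement: 'cpu_rss_GB', 'cpu_percent', ...
   print(sample["cpu_processes"], sample["cpu_open_files"])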
@@ -356,27 +361,35 @@ class ResourceMonitor:
     Parameters
     ----------

-    interval : int, float
+    interval
         Number of seconds to wait between each measurement (maybe a floating
         point number as accepted by :py:func:`time.sleep`)

-    has_gpu : bool
+    has_gpu
         A flag indicating if we have a GPU installed on the platform or not

-    main_pid : int
+    main_pid
         The main process identifier to monitor

-    logging_level: int
+    logging_level
         The logging level to use for logging from launched processes
     """

-    def __init__(self, interval, has_gpu, main_pid, logging_level):
+    def __init__(
+        self,
+        interval: int | float,
+        has_gpu: bool,
+        main_pid: int,
+        logging_level: int,
+    ):
         self.interval = interval
         self.has_gpu = has_gpu
         self.main_pid = main_pid
         self.stop_event = multiprocessing.Event()
         self.summary_event = multiprocessing.Event()
-        self.q = multiprocessing.Queue()
+        self.q: multiprocessing.Queue[
+            dict[str, list[int | float]]
+        ] = multiprocessing.Queue()
         self.logging_level = logging_level

         self.monitor = multiprocessing.Process(
@@ -393,23 +406,23 @@ class ResourceMonitor:
             ),
         )

-        self.data = None
-
-    @staticmethod
-    def monitored_keys(has_gpu):
-        return _InformationGatherer(has_gpu, None, logger).keys
+        self.data: dict[str, int | float] | None = None

-    def __enter__(self):
+    def __enter__(self) -> None:
         """Starts the monitoring process."""
         self.monitor.start()

-    def trigger_summary(self):
+    def checkpoint(self) -> None:
+        """Forces the monitoring process to yield data and clear the internal
+        accumulator."""
         self.summary_event.set()
         try:
-            data = self.q.get(timeout=2 * self.interval)
+            data: dict[str, list[int | float]] = self.q.get(
+                timeout=2 * self.interval
+            )
         except queue.Empty:
-            logger.warn(
+            logger.warning(
                 f"CPU/GPU resource monitor did not provide anything when "
                 f"joined (even after a {2*self.interval}-second timeout - "
                 f"this is normally due to exceptions on the monitoring process. "
@@ -417,19 +430,18 @@ class ResourceMonitor:
             )
             self.data = None
         else:
-            # summarize the returned data by creating means
-            summary = []
-            for k, values in data:
+            # summarize the returned data by creating averages
+            self.data = {}
+            for k, values in data.items():
                 if values:
                     if k in ("cpu_processes", "cpu_open_files"):
-                        summary.append((k, numpy.max(values)))
+                        self.data[k] = numpy.max(values)
                     else:
-                        summary.append((k, numpy.mean(values)))
+                        self.data[k] = float(numpy.mean(values))
                 else:
-                    summary.append((k, 0.0))
-            self.data = tuple(summary)
+                    self.data[k] = 0.0

-    def __exit__(self, *exc):
+    def __exit__(self, *_) -> None:
         """Stops the monitoring process and returns the summary of
         observations."""
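
Similarly, the renamed ``checkpoint()`` entry point (previously ``trigger_summary()``) is meant to be driven from within the context-manager block.  A minimal usage sketch, under the same assumptions as above and with a hypothetical ``work()`` payload:

.. code:: python

   import logging
   import os

   from ptbench.utils.resources import ResourceMonitor

   monitor = ResourceMonitor(
       interval=1.0,  # seconds between measurements
       has_gpu=False,  # only set True where nvidia-smi is available
       main_pid=os.getpid(),  # watch this process and its children
       logging_level=logging.WARNING,
   )
   with monitor:  # __enter__ starts the background sampling process
       work()  # hypothetical workload being profiled
       monitor.checkpoint()  # averages collected samples into monitor.data
   print(monitor.data)  # e.g. {'cpu_percent': 73.2, 'cpu_rss_GB': 1.4, ...}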