From c464ae3eed032bd5e90bf9673720958bd180cdaf Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Mon, 3 Jul 2023 13:10:29 +0200 Subject: [PATCH] Updates to the data module and models --- src/ptbench/data/datamodule.py | 133 ++++++++++++++++++--------------- src/ptbench/engine/trainer.py | 47 +++++++----- src/ptbench/models/densenet.py | 15 ++++ src/ptbench/models/pasa.py | 10 +++ src/ptbench/scripts/train.py | 61 ++++++--------- 5 files changed, 146 insertions(+), 120 deletions(-) diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index 633ebd62..98e003c5 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import collections +import logging import multiprocessing import sys import typing @@ -12,6 +13,8 @@ import torch import torch.utils.data import torchvision.transforms +logger = logging.getLogger(__name__) + def _setup_dataloader_multiproc_parameters( parallel: int, @@ -80,7 +83,8 @@ class _DelayedLoadingDataset(torch.utils.data.Dataset): interval ``[0, 1]``, and a dictionary with further metadata attributes. transforms - A set of transforms that should be applied on-the-fly for this dataset. + A set of transforms that should be applied on-the-fly for this dataset, + to fit the output of the raw-data-loader to the model of interest. """ def __init__( @@ -113,7 +117,8 @@ class _CachedDataset(torch.utils.data.Dataset): """Basically, a list of preloaded samples. This dataset will load all samples from the split during construction - instead of delaying that to the indexing. + instead of delaying that to the indexing. Beyong raw-data-loading, + ``transforms`` given upon construction contribute to the cached samples. Parameters @@ -130,7 +135,9 @@ class _CachedDataset(torch.utils.data.Dataset): interval ``[0, 1]``, and a dictionary with further metadata attributes. transforms - A set of transforms that should be applied on-the-fly for this dataset. + A set of transforms that should be applied to the cached samples for + this dataset, to fit the output of the raw-data-loader to the model of + interest. """ def __init__( @@ -143,8 +150,8 @@ class _CachedDataset(torch.utils.data.Dataset): typing.Callable[[torch.Tensor], torch.Tensor] ] = [], ): - self.data = [raw_data_loader(k) for k in split] self.transform = torchvision.transforms.Compose(*transforms) + self.data = [raw_data_loader(k) for k in split] def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]: tensor, metadata = self.data[key] @@ -286,21 +293,13 @@ class CachingDataModule(lightning.LightningDataModule): particularly important in highly unbalanced datasets. The function :py:func:`get_sample_weights` may help you in this aspect. - data_augmentations - A list of torchvision transforms (torch modules) that will be applied - on training set samples to create data augmentations during the - training of a model. Augmentation transform pipelines are applied - *after* the raw data is loaded, and before ``model_transforms``. - Augmentation transforms assume they receive a torch tensor representing - an image as input (see :py:class:`torchvision.transforms.ToTensor` for - details), in the range ``[0, 1]``. - model_transforms - A list of torchvision transforms (torch modules) that will be applied - after data augmentation transforms, and just before data is fed into - the model for all data loaders produced by this data module. 
This part - of the pipeline receives data as output by the raw-data-loader, or from - data augmentations, if any is specified. + A list of transforms (torch modules) that will be applied after + raw-data-loading, and just before data is fed into the model or + eventual data-augmentation transformations for all data loaders + produced by this data module. This part of the pipeline receives data + as output by the raw-data-loader, or model-related transforms (e.g. + resize adaptions), if any is specified. batch_size Number of samples in every **training** batch (this parameter affects @@ -348,9 +347,6 @@ class CachingDataModule(lightning.LightningDataModule): ], cache_samples: bool = False, train_sampler: typing.Optional[torch.utils.data.Sampler] = None, - data_augmentations: list[ - typing.Callable[[torch.Tensor], torch.Tensor] - ] = [], model_transforms: list[ typing.Callable[[torch.Tensor], torch.Tensor] ] = [], @@ -367,7 +363,6 @@ class CachingDataModule(lightning.LightningDataModule): self.raw_data_loader = raw_data_loader self.cache_samples = cache_samples self.train_sampler = train_sampler - self.data_augmentations = data_augmentations self.model_transforms = model_transforms self.drop_incomplete_batch = drop_incomplete_batch @@ -437,10 +432,36 @@ class CachingDataModule(lightning.LightningDataModule): self._batch_chunk_count = batch_chunk_count self._chunk_size = self._batch_size // self._batch_chunk_count + def _setup_dataset(self, name: str) -> None: + """Sets-up a single dataset from the input data split. + + Parameters + ---------- + + name + Name of the dataset to setup. + """ + + if name in self._datasets: + logger.info(f"Dataset {name} is already setup. Not reloading it.") + return + if self.cache_samples: + self._datasets[name] = _CachedDataset( + self.database_split[name], + self.raw_data_loader, + self.model_transforms, + ) + else: + self._datasets[name] = _DelayedLoadingDataset( + self.database_split[name], + self.raw_data_loader, + self.model_transforms, + ) + def setup(self, stage: str) -> None: """Sets up datasets for different tasks on the pipeline. - This method should setup (load, pre-process, etc) all subsets required + This method should setup (load, pre-process, etc) all datasets required for a particular ``stage`` (fit, validate, test, predict), and keep them ready to be used on one of the `_dataloader()` functions that are pertinent for such stage. 
@@ -463,62 +484,50 @@ class CachingDataModule(lightning.LightningDataModule): * ``predict``: uses only the test dataset """ - def _setup(name, transforms): - if self.cache_samples: - self._datasets[name] = _CachedDataset( - self.database_split[name], self.raw_data_loader, transforms - ) - else: - self._datasets[name] = _DelayedLoadingDataset( - self.database_split[name], self.raw_data_loader, transforms - ) - if stage == "fit": - _setup("train", self.data_augmentations + self.model_transforms) - _setup("validation", self.model_transforms) + self._setup_dataset("train") + self._setup_dataset("validation") for k in self.database_split: if k.startswith("monitor-"): - _setup(k, self.model_transforms) + self._setup_dataset(k) elif stage == "validate": - _setup("validation", self.model_transforms) + self._setup_dataset("validation") for k in self.database_split: if k.startswith("monitor-"): - _setup(k, self.model_transforms) + self._setup_dataset(k) elif stage == "test": - _setup("test", self.model_transforms) + self._setup_dataset("test") elif stage == "predict": - _setup("test", self.model_transforms) + self._setup_dataset("test") + + def teardown(self, stage: str) -> None: + """Unset-up datasets for different tasks on the pipeline. + + This method unsets (unload, remove from memory, etc) all datasets required + for a particular ``stage`` (fit, validate, test, predict). - def unaugmented_train_dataloader(self) -> torch.utils.data.DataLoader: - """Returns a version of the train dataloader without augmentations. + If you have set ``cache_samples``, samples are loaded, this may + effectivley release all the associated memory. - Use this method to obtain a version of the train dataloader without - augmentations, to compute input normalisation factors (e.g. mean and - standard deviation or min-max parameterisations). + Parameters + ---------- - Returns - ------- + stage + Name of the stage to which the teardown is applicable. Can be one of + ``fit``, ``validate``, ``test`` or ``predict``. 
Each stage + typically uses the following data loaders: - dataloader - The unaugmented train dataloader + * ``fit``: uses both train and validation datasets + * ``validate``: uses only the validation dataset + * ``test``: uses only the test dataset + * ``predict``: uses only the test dataset """ - dataset = _DelayedLoadingDataset( - self.database_split["train"], - self.raw_data_loader, - self.model_transforms, - ) - return torch.utils.data.DataLoader( - dataset, - shuffle=False, - batch_size=self._chunk_size, - drop_last=self.drop_incomplete_batch, - pin_memory=self.pin_memory, - **self._dataloader_multiproc, - ) + + self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} def train_dataloader(self) -> torch.utils.data.DataLoader: """Returns the train data loader.""" diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index 2c8bdc55..6b156f86 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -7,10 +7,10 @@ import logging import os import shutil -from lightning.pytorch import Trainer -from lightning.pytorch.callbacks import ModelCheckpoint -from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger -from lightning.pytorch.utilities.model_summary import ModelSummary +import lightning.pytorch +import lightning.pytorch.callbacks +import lightning.pytorch.loggers +import torch.nn from ..utils.accelerator import AcceleratorProcessor from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants @@ -19,7 +19,7 @@ from .callbacks import LoggingCallback logger = logging.getLogger(__name__) -def check_gpu(device): +def check_gpu(device: str) -> None: """Check the device type and the availability of GPU. Parameters @@ -35,32 +35,37 @@ def check_gpu(device): ), f"Device set to '{device}', but nvidia-smi is not installed" -def save_model_summary(output_folder, model): +def save_model_summary( + output_folder: str, model: torch.nn.Module +) -> tuple[lightning.pytorch.callbacks.ModelSummary, int]: """Save a little summary of the model in a txt file. Parameters ---------- - output_folder : str + output_folder output path - model : :py:class:`torch.nn.Module` + model Network (e.g. driu, hed, unet) Returns ------- - r : str + r The model summary in a text format. - n : int + n The number of parameters of the model. """ summary_path = os.path.join(output_folder, "model_summary.txt") logger.info(f"Saving model summary at {summary_path}...") with open(summary_path, "w") as f: - summary = ModelSummary(model, max_depth=-1) + summary = lightning.pytorch.callbacks.ModelSummary(model, max_depth=-1) f.write(str(summary)) - return summary, ModelSummary(model).total_parameters + return ( + summary, + lightning.pytorch.callbacks.ModelSummary(model).total_parameters, + ) def static_information_to_csv(static_logfile_name, device, n): @@ -70,7 +75,8 @@ def static_information_to_csv(static_logfile_name, device, n): ---------- static_logfile_name : str - The static file name which is a join between the output folder and "constant.csv" + The static file name which is a join between the output folder and + "constant.csv" """ if os.path.exists(static_logfile_name): backup = static_logfile_name + "~" @@ -188,7 +194,8 @@ def run( not save intermediary checkpoints. accelerator : str - A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0) + A string indicating the accelerator to use (e.g. "cpu" or "gpu"). 
The + device can also be specified (gpu:0) arguments : dict Start and end epochs: @@ -215,10 +222,12 @@ def run( os.makedirs(output_folder, exist_ok=True) # Save model summary - r, n = save_model_summary(output_folder, model) + _, n = save_model_summary(output_folder, model) - csv_logger = CSVLogger(output_folder, "logs_csv") - tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard") + csv_logger = lightning.pytorch.loggers.CSVLogger(output_folder, "logs_csv") + tensorboard_logger = lightning.pytorch.loggers.TensorBoardLogger( + output_folder, "logs_tensorboard" + ) resource_monitor = ResourceMonitor( interval=monitoring_interval, @@ -227,7 +236,7 @@ def run( logging_level=logging.ERROR, ) - checkpoint_callback = ModelCheckpoint( + checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint( output_folder, "model_lowest_valid_loss", save_last=True, @@ -251,7 +260,7 @@ def run( devices = accelerator_processor.device with resource_monitor: - trainer = Trainer( + trainer = lightning.pytorch.Trainer( accelerator=accelerator_processor.accelerator, devices=devices, max_epochs=max_epoch, diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index 1aa1b268..e61d7dec 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -2,11 +2,15 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import logging + import lightning.pytorch as pl import torch import torch.nn as nn import torchvision.models as models +logger = logging.getLogger(__name__) + class Densenet(pl.LightningModule): """Densenet module. @@ -25,6 +29,8 @@ class Densenet(pl.LightningModule): ): super().__init__() + # Saves all hyper parameters declared on __init__ into ``self.hparams`. + # You can access those by their name, like `self.hparams.optimizer` self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) self.name = "Densenet" @@ -55,10 +61,19 @@ class Densenet(pl.LightningModule): if self.pretrained: from .normalizer import make_imagenet_normalizer + logger.warning( + "ImageNet pre-trained densenet model - NOT" + "computing z-norm factors from training data. " + "Using preset factors from torchvision." + ) self.normalizer = make_imagenet_normalizer() else: from .normalizer import make_z_normalizer + logger.info( + "Uninitialised densenet model - " + "computing z-norm factors from training data." + ) self.normalizer = make_z_normalizer(dataloader) def training_step(self, batch, batch_idx): diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 61105dbe..fbc73f81 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -2,12 +2,16 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import logging + import lightning.pytorch as pl import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.data +logger = logging.getLogger(__name__) + class PASA(pl.LightningModule): """PASA module. @@ -24,6 +28,8 @@ class PASA(pl.LightningModule): ): super().__init__() + # Saves all hyper parameters declared on __init__ into ``self.hparams`. + # You can access those by their name, like `self.hparams.criterion` self.save_hyperparameters() self.name = "pasa" @@ -137,6 +143,10 @@ class PASA(pl.LightningModule): """ from .normalizer import make_z_normalizer + logger.info( + "Uninitialised densenet model - " + "computing z-norm factors from training data." 
+ ) self.normalizer = make_z_normalizer(dataloader) def training_step(self, batch, _): diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index eeb5b869..01f294d7 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -6,9 +6,6 @@ import click from clapper.click import ConfigCommand, ResourceOption, verbosity_option from clapper.logging import setup -from lightning.pytorch import seed_everything - -from ..utils.checkpointer import get_checkpoint logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -65,21 +62,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") required=True, cls=ResourceOption, ) -@click.option( - "--criterion", - help="A loss function to compute the CNN error for every sample " - "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)", - required=True, - cls=ResourceOption, -) -@click.option( - "--criterion-valid", - help="A specific loss function for the validation set to compute the CNN" - "error for every sample respecting the PyTorch API for loss functions" - "(see torch.nn.modules.loss)", - required=False, - cls=ResourceOption, -) @click.option( "--batch-size", "-b", @@ -159,7 +141,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--accelerator", "-a", - help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)', + help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). ' + "The device can also be specified (gpu:0)", show_default=True, required=True, default="cpu", @@ -167,7 +150,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ) @click.option( "--cache-samples", - help="If set to True, loads the sample into memory, otherwise loads them at runtime.", + help="If set to True, loads the sample into memory, " + "otherwise loads them at runtime.", required=True, show_default=True, default=False, @@ -196,16 +180,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") default=-1, cls=ResourceOption, ) -@click.option( - "--normalization", - "-n", - help="Z-Normalization of input images: 'imagenet' for ImageNet parameters," - " 'current' for parameters of the current trainset, " - "'none' for no normalization.", - required=False, - default="none", - cls=ResourceOption, -) @click.option( "--monitoring-interval", "-I", @@ -224,7 +198,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ) @click.option( "--resume-from", - help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a model checkpoint.", + help="Which checkpoint to resume training from. If set, can be one of " + "`best`, `last`, or a path to a model checkpoint.", type=str, required=False, default=None, @@ -238,20 +213,16 @@ def train( batch_size, batch_chunk_count, drop_incomplete_batch, - criterion, - criterion_valid, datamodule, checkpoint_period, accelerator, cache_samples, seed, parallel, - normalization, monitoring_interval, resume_from, - **_, ): - """Trains an CNN to perform tuberculosis detection. + """Trains an CNN to perform image classification. Training is performed for a configurable number of epochs, and generates at least a final_model.pth. 
It may also generate a number @@ -263,21 +234,33 @@ def train( import torch.cuda import torch.nn - from ..data.dataset import reweight_BCEWithLogitsLoss + from lightning.pytorch import seed_everything + from ..engine.trainer import run + from ..utils.checkpointer import get_checkpoint seed_everything(seed) checkpoint_file = get_checkpoint(output_folder, resume_from) + # reset datamodule with user configurable options datamodule.set_chunk_size(batch_size, batch_chunk_count) + datamodule.drop_incomplete_batch = drop_incomplete_batch + datamodule.cache_samples = cache_samples datamodule.parallel = parallel datamodule.prepare_data() datamodule.setup(stage="fit") - reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid) - model.set_normalizer(datamodule.unaugmented_train_dataloader()) + # Sets the model normalizer with the unaugmented-train-subset. + # this call may be a NOOP, if the model was pre-trained and expects + # different weights for the normalisation layer. + model.set_normalizer(datamodule.train_dataloader()) + + # Rebalances the loss criterion based on the relative proportion of class + # examples available in the training set. Also affects the validation loss + # if a validation set is available on the data module. + model.set_bce_loss_weights(datamodule) arguments = {} arguments["max_epoch"] = epochs -- GitLab
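
Below is a minimal, hypothetical sketch of how the reworked CachingDataModule is meant to be driven after this patch, mirroring the calls made in scripts/train.py above. The split contents, the raw-data loader and the tensor shape are made-up placeholders and not part of the patch; only the CachingDataModule methods and attributes that appear in the diff are assumed, and the positional order of the first two constructor arguments follows the parameter documentation in the diff.

    import torch

    from ptbench.data.datamodule import CachingDataModule

    # Hypothetical split: dataset names -> lists of sample descriptors (placeholders).
    split = {
        "train": [("s0", 0), ("s1", 1)],
        "validation": [("s2", 1)],
        "test": [("s3", 0)],
    }

    def raw_data_loader(sample):
        # Placeholder loader: returns (tensor in [0, 1], metadata dict), as the
        # datamodule documentation requires.  A real loader would read an image
        # file and attach its metadata.
        name, label = sample
        return torch.rand(1, 224, 224), dict(name=name, label=label)

    # Assumes database_split and raw_data_loader are the first two arguments,
    # in the order they are documented in the diff.
    datamodule = CachingDataModule(
        split,
        raw_data_loader,
        cache_samples=True,   # preload (and transform) every sample at setup time
        model_transforms=[],  # e.g. resize adaptions required by the model
    )

    # User-configurable options are pushed into the datamodule, as done in
    # scripts/train.py above.
    datamodule.set_chunk_size(2, 1)  # batch size 2, no gradient accumulation

    datamodule.prepare_data()
    datamodule.setup(stage="fit")    # a repeated call is now a no-op (datasets are reused)

    # A model would then consume the same dataloader for normalisation and for
    # loss re-balancing, e.g.:
    #   model.set_normalizer(datamodule.train_dataloader())
    #   model.set_bce_loss_weights(datamodule)
    batch = next(iter(datamodule.train_dataloader()))

    datamodule.teardown(stage="fit")  # drops the datasets, releasing cached samples

Because augmentations no longer live in the datamodule, the plain train_dataloader() doubles as the source of normalisation statistics, which is why the separate unaugmented_train_dataloader() could be removed.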