diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py
index 635c9ed9b7ed23a255ba9f5e5f79828de17b7f17..b2d3ab1d48a9bd0d3a66e4a66c50fcc2300be2c9 100644
--- a/src/ptbench/configs/datasets/shenzhen/default.py
+++ b/src/ptbench/configs/datasets/shenzhen/default.py
@@ -10,6 +10,8 @@
 * See :py:mod:`ptbench.data.shenzhen` for dataset details
 """
 
+from ....data.shenzhen.datamodule import ShenzhenDataModule
 from . import _maker
 
 dataset = _maker("default")
+datamodule = ShenzhenDataModule
diff --git a/src/ptbench/data/shenzhen/datamodule.py b/src/ptbench/data/shenzhen/datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..60de2efb1fa1d79be5a01054578c3718904f1d2b
--- /dev/null
+++ b/src/ptbench/data/shenzhen/datamodule.py
@@ -0,0 +1,140 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import lightning as pl
+import torch
+
+from clapper.logging import setup
+from torch.utils.data import DataLoader, WeightedRandomSampler
+
+from ptbench.configs.datasets import get_samples_weights
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+class ShenzhenDataModule(pl.LightningDataModule):
+    def __init__(
+        self,
+        dataset,
+        train_batch_size=1,
+        predict_batch_size=1,
+        drop_incomplete_batch=False,
+        multiproc_kwargs=None,
+    ):
+        super().__init__()
+
+        self.dataset = dataset
+
+        self.train_batch_size = train_batch_size
+        self.predict_batch_size = predict_batch_size
+
+        self.drop_incomplete_batch = drop_incomplete_batch
+        self.pin_memory = (
+            torch.cuda.is_available()
+        )  # should only be true if GPU available and using it
+
+        self.multiproc_kwargs = multiproc_kwargs
+
+    def setup(self, stage: str):
+        if stage == "fit":
+            if "__train__" in self.dataset:
+                logger.info("Found (dedicated) '__train__' set for training")
+                self.train_dataset = self.dataset["__train__"]
+            else:
+                self.train_dataset = self.dataset["train"]
+
+            if "__valid__" in self.dataset:
+                logger.info("Found (dedicated) '__valid__' set for validation")
+                self.validation_dataset = self.dataset["__valid__"]
+
+            if "__extra_valid__" in self.dataset:
+                if not isinstance(self.dataset["__extra_valid__"], list):
+                    raise RuntimeError(
+                        f"If present, dataset['__extra_valid__'] must be a list, "
+                        f"but you passed a {type(self.dataset['__extra_valid__'])}, "
+                        f"which is invalid."
+                    )
+                logger.info(
+                    f"Found {len(self.dataset['__extra_valid__'])} extra validation "
+                    f"set(s) to be tracked during training"
+                )
+                logger.info(
+                    "Extra validation sets are NOT used for model checkpointing!"
+                )
+                self.extra_validation_datasets = self.dataset["__extra_valid__"]
+            else:
+                self.extra_validation_datasets = None
+
+        if stage == "predict":
+            self.predict_dataset = []
+
+            for split_key in self.dataset.keys():
+                if split_key.startswith("_"):
+                    logger.info(
+                        f"Skipping dataset '{split_key}' (not to be evaluated)"
+                    )
+                    continue
+
+                else:
+                    self.predict_dataset.append(self.dataset[split_key])
+
+    def train_dataloader(self):
+        train_samples_weights = get_samples_weights(self.train_dataset)
+
+        train_sampler = WeightedRandomSampler(
+            train_samples_weights, len(train_samples_weights), replacement=True
+        )
+
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.train_batch_size,
+            drop_last=self.drop_incomplete_batch,
+            pin_memory=self.pin_memory,
+            sampler=train_sampler,
+            **self.multiproc_kwargs,
+        )
+
+    def val_dataloader(self):
+        loaders_dict = {}
+
+        val_loader = DataLoader(
+            dataset=self.validation_dataset,
+            batch_size=self.train_batch_size,
+            shuffle=False,
+            drop_last=False,
+            pin_memory=self.pin_memory,
+            **self.multiproc_kwargs,
+        )
+
+        loaders_dict["validation_loader"] = val_loader
+
+        if self.extra_validation_datasets is not None:
+            for set_idx, extra_set in enumerate(self.extra_validation_datasets):
+                extra_val_loader = DataLoader(
+                    dataset=extra_set,
+                    batch_size=self.train_batch_size,
+                    shuffle=False,
+                    drop_last=False,
+                    pin_memory=self.pin_memory,
+                    **self.multiproc_kwargs,
+                )
+
+                loaders_dict[
+                    f"extra_validation_loader{set_idx}"
+                ] = extra_val_loader
+
+        return loaders_dict
+
+    def predict_dataloader(self):
+        loaders_dict = {}
+
+        for set_idx, pred_set in enumerate(self.predict_dataset):
+            loaders_dict[set_idx] = DataLoader(
+                dataset=pred_set,
+                batch_size=self.predict_batch_size,
+                shuffle=False,
+                pin_memory=self.pin_memory,
+            )
+
+        return loaders_dict
diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index b266ae6221cf9a925ff941f1c99bdfdd044fa23f..962a761cce0b2a82bd6f41a60c92c49333f59da8 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -1,6 +1,8 @@
 import csv
 import time
 
+from collections import defaultdict
+
 import numpy
 
 from lightning.pytorch import Callback
@@ -17,6 +19,7 @@ class LoggingCallback(Callback):
         super().__init__()
         self.training_loss = []
         self.validation_loss = []
+        self.extra_validation_loss = defaultdict(list)
 
         self.start_training_time = 0
         self.start_epoch_time = 0
@@ -37,6 +40,13 @@
     ):
         self.validation_loss.append(outputs["validation_loss"].item())
 
+        extra_validation_keys = set(outputs.keys()) - {"validation_loss"}
+        if len(extra_validation_keys) > 0:
+            for extra_validation_loss_key in extra_validation_keys:
+                self.extra_validation_loss[extra_validation_loss_key].append(
+                    outputs[extra_validation_loss_key].item()
+                )
+
     def on_validation_epoch_end(self, trainer, pl_module):
         self.resource_monitor.trigger_summary()
 
@@ -52,6 +62,15 @@
         self.log("learning_rate", pl_module.hparams["optimizer_configs"]["lr"])
         self.log("validation_loss", numpy.average(self.validation_loss))
 
+        if len(self.extra_validation_loss) > 0:
+            for (
+                extra_valid_loss_key,
+                extra_valid_loss_values,
+            ) in self.extra_validation_loss.items():
+                self.log(
+                    extra_valid_loss_key, numpy.average(extra_valid_loss_values)
+                )
+
         queue_retries = 0
         # In case the resource monitor takes longer to fetch data from the queue, we wait
         # Give up after self.resource_monitor.interval * self.max_queue_retries if cannot retrieve metrics from queue
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index a85a3da566691922323fbeb8a56d3472def48389..2c8bdc55f161ce567ddd7cb6641ac728ce3b7dbd 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -149,9 +149,7 @@ def create_logfile_fields(valid_loader, extra_valid_loaders, device):
 
 def run(
     model,
-    data_loader,
-    valid_loader,
-    extra_valid_loaders,
+    datamodule,
     checkpoint_period,
     accelerator,
     arguments,
@@ -263,4 +261,4 @@
         callbacks=[LoggingCallback(resource_monitor), checkpoint_callback],
     )
 
-    _ = trainer.fit(model, data_loader, valid_loader, ckpt_path=checkpoint)
+    _ = trainer.fit(model, datamodule, ckpt_path=checkpoint)
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 125867bda1aa4f6e2317708cc5010d9120518f46..cae11375b66e8cc24287609363474a36065a2016 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -141,7 +141,7 @@ class PASA(pl.LightningModule):
 
         return {"loss": training_loss}
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[1]
         labels = batch[2]
 
@@ -159,9 +159,12 @@
         )
         validation_loss = self.hparams.criterion_valid(outputs, labels.double())
 
-        return {"validation_loss": validation_loss}
+        if dataloader_idx == 0:
+            return {"validation_loss": validation_loss}
+        else:
+            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
-    def predict_step(self, batch, batch_idx, grad_cams=False):
+    def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
 
         names = batch[0]
         images = batch[1]
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 12c5a287f5682340ecb1275f4439ebd4056c657b..eb6910cdce997df94155e602591508b99ec512dd 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -13,25 +13,6 @@ from ..utils.checkpointer import get_checkpoint
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
-def set_reproducible_cuda():
-    """Turns-off all CUDA optimizations that would affect reproducibility.
-
-    For full reproducibility, also ensure not to use multiple (parallel) data
-    lowers. That is setup ``num_workers=0``.
-
-    Reference: `PyTorch page for reproducibility
-    <https://pytorch.org/docs/stable/notes/randomness.html>`_.
-    """
-    import torch.backends.cudnn
-
-    # ensure to use only optimization algos for cuda that are known to have
-    # a deterministic effect (not random)
-    torch.backends.cudnn.deterministic = True
-
-    # turns off any optimization tricks
-    torch.backends.cudnn.benchmark = False
-
-
 @click.command(
     entry_point_group="ptbench.config",
     cls=ConfigCommand,
@@ -62,6 +43,12 @@
     required=True,
     cls=ResourceOption,
 )
+@click.option(
+    "--datamodule",
+    help="A lightning DataModule that wraps the dataset and provides the train/validation/prediction dataloaders",
+    required=True,
+    cls=ResourceOption,
+)
 @click.option(
     "--dataset",
     "-d",
@@ -235,7 +222,7 @@
 )
 @click.option(
     "--resume-from",
-    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a model checkpoint.",
+    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a model checkpoint.",
     type=str,
     required=False,
    default=None,
@@ -251,6 +238,7 @@ def train(
     drop_incomplete_batch,
     criterion,
     criterion_valid,
+    datamodule,
     dataset,
     checkpoint_period,
     accelerator,
@@ -277,45 +265,13 @@
     import torch.nn
 
     from torch.nn import BCEWithLogitsLoss
-    from torch.utils.data import DataLoader, WeightedRandomSampler
+    from torch.utils.data import DataLoader
 
-    from ..configs.datasets import get_positive_weights, get_samples_weights
+    from ..configs.datasets import get_positive_weights
     from ..engine.trainer import run
 
     seed_everything(seed)
 
-    use_dataset = dataset
-    validation_dataset = None
-    extra_validation_datasets = []
-
-    if isinstance(dataset, dict):
-        if "__train__" in dataset:
-            logger.info("Found (dedicated) '__train__' set for training")
-            use_dataset = dataset["__train__"]
-        else:
-            use_dataset = dataset["train"]
-
-        if "__valid__" in dataset:
-            logger.info("Found (dedicated) '__valid__' set for validation")
-            logger.info("Will checkpoint lowest loss model on validation set")
-            validation_dataset = dataset["__valid__"]
-
-        if "__extra_valid__" in dataset:
-            if not isinstance(dataset["__extra_valid__"], list):
-                raise RuntimeError(
-                    f"If present, dataset['__extra_valid__'] must be a list, "
-                    f"but you passed a {type(dataset['__extra_valid__'])}, "
-                    f"which is invalid."
-                )
-            logger.info(
-                f"Found {len(dataset['__extra_valid__'])} extra validation "
-                f"set(s) to be tracked during training"
-            )
-            logger.info(
-                "Extra validation sets are NOT used for model checkpointing!"
-            )
-            extra_validation_datasets = dataset["__extra_valid__"]
-
     # PyTorch dataloader
     multiproc_kwargs = dict()
     if parallel < 0:
@@ -340,31 +296,25 @@
     else:
         batch_chunk_size = batch_size // batch_chunk_count
 
-    # Create weighted random sampler
-    train_samples_weights = get_samples_weights(use_dataset)
-    train_sampler = WeightedRandomSampler(
-        train_samples_weights, len(train_samples_weights), replacement=True
+    datamodule = datamodule(
+        dataset,
+        train_batch_size=batch_chunk_size,
+        multiproc_kwargs=multiproc_kwargs,
     )
+    # Manually calling these as we need to access some values to reweight the criterion
+    datamodule.prepare_data()
+    datamodule.setup(stage="fit")
+
+    train_dataset = datamodule.train_dataset
+    validation_dataset = datamodule.validation_dataset
 
     # Redefine a weighted criterion if possible
     if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
-        positive_weights = get_positive_weights(use_dataset)
+        positive_weights = get_positive_weights(train_dataset)
         model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
     else:
        logger.warning("Weighted criterion not supported")
 
-    # PyTorch dataloader
-
-    data_loader = DataLoader(
-        dataset=use_dataset,
-        batch_size=batch_chunk_size,
-        drop_last=drop_incomplete_batch,
-        pin_memory=torch.cuda.is_available(),
-        sampler=train_sampler,
-        **multiproc_kwargs,
-    )
-
-    valid_loader = None
     if validation_dataset is not None:
         # Redefine a weighted valid criterion if possible
         if (
@@ -378,27 +328,6 @@
         else:
             logger.warning("Weighted valid criterion not supported")
 
-        valid_loader = DataLoader(
-            dataset=validation_dataset,
-            batch_size=batch_chunk_size,
-            shuffle=False,
-            drop_last=False,
-            pin_memory=torch.cuda.is_available(),
-            **multiproc_kwargs,
-        )
-
-    extra_valid_loaders = [
-        DataLoader(
-            dataset=k,
-            batch_size=batch_chunk_size,
-            shuffle=False,
-            drop_last=False,
-            pin_memory=torch.cuda.is_available(),
-            **multiproc_kwargs,
-        )
-        for k in extra_validation_datasets
-    ]
-
     # Create z-normalization model layer if needed
     if normalization == "imagenet":
         model.normalizer.set_mean_std(
@@ -407,7 +336,9 @@
         logger.info("Z-normalization with ImageNet mean and std")
     elif normalization == "current":
         # Compute mean/std of current train subset
-        temp_dl = DataLoader(dataset=use_dataset, batch_size=len(use_dataset))
+        temp_dl = DataLoader(
+            dataset=train_dataset, batch_size=len(train_dataset)
+        )
 
         data = next(iter(temp_dl))
         mean = data[1].mean(dim=[0, 2, 3])
@@ -446,9 +377,7 @@
 
     run(
         model=model,
-        data_loader=data_loader,
-        valid_loader=valid_loader,
-        extra_valid_loaders=extra_valid_loaders,
+        datamodule=datamodule,
        checkpoint_period=checkpoint_period,
         accelerator=accelerator,
         arguments=arguments,
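
Not part of the patch above: a minimal usage sketch of the new ShenzhenDataModule, mirroring what src/ptbench/scripts/train.py now does after this change. The batch size and worker count are placeholder values, and it assumes the default Shenzhen split dict exposes "train"/"__train__" and "__valid__" entries.

```python
# Hypothetical example -- illustration only, values below are assumptions.
from ptbench.configs.datasets.shenzhen.default import dataset, datamodule

# `datamodule` is the ShenzhenDataModule class; train.py instantiates it with
# the split dict produced by _maker("default") plus the dataloader settings.
dm = datamodule(
    dataset,
    train_batch_size=4,                   # assumed value
    multiproc_kwargs={"num_workers": 0},  # assumed value
)

# train.py calls these manually so it can read the training subset and
# re-weight the loss criteria before handing everything to the Trainer.
dm.prepare_data()
dm.setup(stage="fit")

train_loader = dm.train_dataloader()  # weighted random sampling over the train set
val_loaders = dm.val_dataloader()     # dict with "validation_loader" (+ any extras)
```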