From 11b65a401b8a21605a1280d44c32b0f99eb46d9c Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 4 Jul 2023 10:57:39 +0200 Subject: [PATCH] Functional model training step for pasa Implemented BCEWithLogitsLoss reweighting function Removed save_hyperparameters from model Apply augmentation transforms on singular images Fixed model summary --- src/ptbench/configs/models/pasa.py | 13 ++++--- src/ptbench/data/datamodule.py | 12 +++++- src/ptbench/data/dataset.py | 49 +++--------------------- src/ptbench/engine/callbacks.py | 4 +- src/ptbench/engine/trainer.py | 12 ++++-- src/ptbench/models/pasa.py | 61 ++++++++++++++++++++++-------- 6 files changed, 78 insertions(+), 73 deletions(-) diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py index 47324199..49ee76db 100644 --- a/src/ptbench/configs/models/pasa.py +++ b/src/ptbench/configs/models/pasa.py @@ -13,22 +13,25 @@ Reference: [PASA-2019]_ from torch import empty from torch.nn import BCEWithLogitsLoss +from torch.optim import Adam -from ...data.transforms import ElasticDeformation from ...models.pasa import PASA -# config -optimizer_configs = {"lr": 8e-5} - # optimizer -optimizer = "Adam" +optimizer = Adam +optimizer_configs = {"lr": 8e-5} # criterion criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) +from ...data.transforms import ElasticDeformation + augmentation_transforms = [ElasticDeformation(p=0.8)] +# from torchvision.transforms.v2 import ElasticTransform, InterpolationMode +# augmentation_transforms = [ElasticTransform(alpha=1000.0, sigma=30.0, interpolation=InterpolationMode.NEAREST)] + # model model = PASA( criterion, diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index 8c297823..fc9883d2 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -13,6 +13,8 @@ import torch import torch.utils.data import torchvision.transforms +from tqdm import tqdm + logger = logging.getLogger(__name__) @@ -150,8 +152,13 @@ class _CachedDataset(torch.utils.data.Dataset): typing.Callable[[torch.Tensor], torch.Tensor] ] = [], ): - self.transform = torchvision.transforms.Compose(*transforms) - self.data = [raw_data_loader(k) for k in split] + # Cannot unpack empty list + if len(transforms) > 0: + self.transform = torchvision.transforms.Compose([*transforms]) + else: + self.transform = torchvision.transforms.Compose([]) + + self.data = [raw_data_loader(k) for k in tqdm(split)] def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]: tensor, metadata = self.data[key] @@ -446,6 +453,7 @@ class CachingDataModule(lightning.LightningDataModule): logger.info(f"Dataset {name} is already setup. Not reloading it.") return if self.cache_samples: + logger.info(f"Caching {name} dataset") self._datasets[name] = _CachedDataset( self.database_split[name], self.raw_data_loader, diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py index af8a7364..3529ee7f 100644 --- a/src/ptbench/data/dataset.py +++ b/src/ptbench/data/dataset.py @@ -10,11 +10,11 @@ import torch.utils.data logger = logging.getLogger(__name__) -def _get_positive_weights(dataset): +def _get_positive_weights(dataloader): """Compute the positive weights of each class of the dataset to balance the BCEWithLogitsLoss criterion. 
- This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` + This function takes as input a :py:class:`torch.utils.data.DataLoader` and computes the positive weights of each class to use them to have a balanced loss. @@ -22,9 +22,8 @@ def _get_positive_weights(dataset): Parameters ---------- - dataset : torch.utils.data.dataset.Dataset - An instance of torch.utils.data.dataset.Dataset - ConcatDataset are supported + dataloader : :py:class:`torch.utils.data.DataLoader` + A DataLoader from which to compute the positive weights. Must contain a 'label' key in the metadata returned by __getitem__(). Returns @@ -35,14 +34,8 @@ def _get_positive_weights(dataset): """ targets = [] - if isinstance(dataset, torch.utils.data.ConcatDataset): - for ds in dataset.datasets: - for s in ds._samples: - targets.append(s["label"]) - - else: - for s in dataset._samples: - targets.append(s["label"]) + for batch in dataloader: + targets.extend(batch[1]["label"]) targets = torch.tensor(targets) @@ -71,33 +64,3 @@ def _get_positive_weights(dataset): ) return positive_weights - - -def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid): - from torch.nn import BCEWithLogitsLoss - - datamodule.prepare_data() - datamodule.setup(stage="fit") - - train_dataset = datamodule.train_dataset - validation_dataset = datamodule.validation_dataset - - # Redefine a weighted criterion if possible - if isinstance(criterion, torch.nn.BCEWithLogitsLoss): - positive_weights = _get_positive_weights(train_dataset) - model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights) - else: - logger.warning("Weighted criterion not supported") - - if validation_dataset is not None: - # Redefine a weighted valid criterion if possible - if ( - isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss) - or criterion_valid is None - ): - positive_weights = _get_positive_weights(validation_dataset) - model.hparams.criterion_valid = BCEWithLogitsLoss( - pos_weight=positive_weights - ) - else: - logger.warning("Weighted valid criterion not supported") diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py index d0ac43f9..580cc26f 100644 --- a/src/ptbench/engine/callbacks.py +++ b/src/ptbench/engine/callbacks.py @@ -94,9 +94,7 @@ class LoggingCallback(Callback): self.log("total_time", current_time) self.log("eta", eta_seconds) self.log("loss", numpy.average(self.training_loss)) - self.log( - "learning_rate", pl_module.hparams["optimizer_configs"]["lr"] - ) + self.log("learning_rate", pl_module.optimizer_configs["lr"]) self.log("validation_loss", numpy.sum(self.validation_loss)) if len(self.extra_validation_loss) > 0: diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index 6b156f86..ecf29153 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -51,20 +51,24 @@ def save_model_summary( Returns ------- - r + summary: The model summary in a text format. - n + total_parameters: The number of parameters of the model. 
""" summary_path = os.path.join(output_folder, "model_summary.txt") logger.info(f"Saving model summary at {summary_path}...") with open(summary_path, "w") as f: - summary = lightning.pytorch.callbacks.ModelSummary(model, max_depth=-1) + summary = lightning.pytorch.utilities.model_summary.ModelSummary( + model, max_depth=-1 + ) f.write(str(summary)) return ( summary, - lightning.pytorch.callbacks.ModelSummary(model).total_parameters, + lightning.pytorch.utilities.model_summary.ModelSummary( + model + ).total_parameters, ) diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 76327670..e93b4e61 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.data +import torchvision.transforms logger = logging.getLogger(__name__) @@ -31,7 +32,15 @@ class PASA(pl.LightningModule): self.name = "pasa" - self.augmentation_transforms = augmentation_transforms + self.augmentation_transforms = torchvision.transforms.Compose( + augmentation_transforms + ) + + self.criterion = criterion + self.criterion_valid = criterion_valid + + self.optimizer = optimizer + self.optimizer_configs = optimizer_configs self.normalizer = None @@ -137,7 +146,7 @@ class PASA(pl.LightningModule): Parameters ---------- - dataloader: + dataloader: :py:class:`torch.utils.data.DataLoader` A torch Dataloader from which to compute the mean and std """ from .normalizer import make_z_normalizer @@ -148,6 +157,35 @@ class PASA(pl.LightningModule): ) self.normalizer = make_z_normalizer(dataloader) + def set_bce_loss_weights(self, datamodule): + """Reweights loss weights if BCEWithLogitsLoss is used. + + Parameters + ---------- + + datamodule: + A datamodule implementing train_dataloader() and val_dataloader() + """ + from ..data.dataset import _get_positive_weights + + if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss): + logger.info("Reweighting BCEWithLogitsLoss training criterion.") + train_positive_weights = _get_positive_weights( + datamodule.train_dataloader() + ) + self.criterion = torch.nn.BCEWithLogitsLoss( + pos_weight=train_positive_weights + ) + + if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss): + logger.info("Reweighting BCEWithLogitsLoss validation criterion.") + validation_positive_weights = _get_positive_weights( + datamodule.val_dataloader()["validation"] + ) + self.criterion_valid = torch.nn.BCEWithLogitsLoss( + pos_weight=validation_positive_weights + ) + def training_step(self, batch, _): images = batch[0] labels = batch[1]["label"] @@ -158,12 +196,11 @@ class PASA(pl.LightningModule): labels = torch.reshape(labels, (labels.shape[0], 1)) # Forward pass on the network - augmented_images = self.augmentation_transforms(images) + augmented_images = [self.augmentation_transforms(img) for img in images] + augmented_images = torch.unsqueeze(torch.cat(augmented_images, 0), 1) outputs = self(augmented_images) - # Manually move criterion to selected device, since not part of the model. - self.hparams.criterion = self.hparams.criterion.to(self.device) - training_loss = self.hparams.criterion(outputs, labels.double()) + training_loss = self.criterion(outputs, labels.double()) return {"loss": training_loss} @@ -179,11 +216,7 @@ class PASA(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - # Manually move criterion to selected device, since not part of the model. 
- self.hparams.criterion_valid = self.hparams.criterion_valid.to( - self.device - ) - validation_loss = self.hparams.criterion_valid(outputs, labels.double()) + validation_loss = self.criterion_valid(outputs, labels.double()) if dataloader_idx == 0: return {"validation_loss": validation_loss} @@ -233,9 +266,5 @@ class PASA(pl.LightningModule): # raise NotImplementedError def configure_optimizers(self): - # Dynamically instantiates the optimizer given the configs - optimizer = getattr(torch.optim, self.hparams.optimizer)( - self.parameters(), **self.hparams.optimizer_configs - ) - + optimizer = self.optimizer(self.parameters(), **self.optimizer_configs) return optimizer -- GitLab
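Note on the loss reweighting introduced above: _get_positive_weights() now derives the pos_weight tensor for BCEWithLogitsLoss from the labels yielded by a dataloader. The exact computation sits outside the hunks shown here, so the snippet below is only a minimal, binary-label sketch of the usual negatives-over-positives heuristic; the label values and shapes are illustrative, not taken from the patch.

import torch
from torch.nn import BCEWithLogitsLoss

# Illustrative binary labels, as they would be collected from a dataloader
# (0 = negative, 1 = positive).
targets = torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0, 1.0])

# Common heuristic: pos_weight = negatives / positives, so the
# under-represented positive class contributes proportionally more.
positive_count = targets.sum()
negative_count = targets.numel() - positive_count
pos_weight = (negative_count / positive_count).reshape(1)  # here 4 / 2 = 2.0

criterion = BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(6, 1)  # raw (pre-sigmoid) model outputs
loss = criterion(logits, targets.reshape(-1, 1))
print(float(loss))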
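The training_step change applies the augmentation pipeline to one image at a time and then rebuilds the batch. Below is a standalone sketch of that pattern, assuming single-channel inputs of shape (B, 1, H, W); RandomHorizontalFlip stands in for the ElasticDeformation transform used by the configuration.

import torch
import torchvision.transforms

# Stand-in per-sample augmentation pipeline (not the patch's ElasticDeformation).
augmentation_transforms = torchvision.transforms.Compose(
    [torchvision.transforms.RandomHorizontalFlip(p=0.5)]
)

batch = torch.rand(4, 1, 64, 64)  # assumed single-channel images (B, 1, H, W)

# Iterating over the batch yields (1, H, W) tensors; each one is augmented
# individually, then the batch dimension is rebuilt, as in training_step().
augmented = [augmentation_transforms(img) for img in batch]
augmented = torch.unsqueeze(torch.cat(augmented, 0), 1)

print(augmented.shape)  # torch.Size([4, 1, 64, 64])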
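The configuration now passes the optimizer class (torch.optim.Adam) together with its keyword arguments instead of a string name, so configure_optimizers() can instantiate it directly without getattr(). A small sketch of that pattern with a throwaway parameter list:

import torch
from torch.optim import Adam

optimizer_cls = Adam              # the class object, not the string "Adam"
optimizer_configs = {"lr": 8e-5}

params = [torch.nn.Parameter(torch.zeros(3))]  # throwaway parameters for illustration
optimizer = optimizer_cls(params, **optimizer_configs)
print(optimizer)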