From 222a376213777262323c77f229f3a9561b8aa836 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 30 Apr 2024 10:37:09 +0200 Subject: [PATCH] [train] Add support for multiple validation dataloaders During validation logging, lightning seems to append "/dataloader_idx_n" to the key we define if multiple dataloaders are used. --- src/mednet/data/datamodule.py | 8 ++++---- src/mednet/engine/callbacks.py | 7 +------ src/mednet/engine/trainer.py | 6 +++++- src/mednet/models/alexnet.py | 3 +-- src/mednet/models/densenet.py | 2 +- src/mednet/models/model.py | 33 +++++++++++++++++++++++++-------- src/mednet/models/pasa.py | 3 +-- 7 files changed, 38 insertions(+), 24 deletions(-) diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py index 2de54fdf..6c7d759f 100644 --- a/src/mednet/data/datamodule.py +++ b/src/mednet/data/datamodule.py @@ -758,7 +758,7 @@ class ConcatDataModule(lightning.LightningDataModule): else: self._datasets[name] = _ConcatDataset(datasets) - def _val_dataset_keys(self) -> list[str]: + def val_dataset_keys(self) -> list[str]: """Return list of validation dataset names. Returns @@ -796,11 +796,11 @@ class ConcatDataModule(lightning.LightningDataModule): """ if stage == "fit": - for k in ["train"] + self._val_dataset_keys(): + for k in ["train"] + self.val_dataset_keys(): self._setup_dataset(k) elif stage == "validate": - for k in self._val_dataset_keys(): + for k in self.val_dataset_keys(): self._setup_dataset(k) elif stage == "test": @@ -889,7 +889,7 @@ class ConcatDataModule(lightning.LightningDataModule): self._datasets[k], **validation_loader_opts, ) - for k in self._val_dataset_keys() + for k in self.val_dataset_keys() } def test_dataloader(self) -> dict[str, DataLoader]: diff --git a/src/mednet/engine/callbacks.py b/src/mednet/engine/callbacks.py index 4c19ac55..501f2c25 100644 --- a/src/mednet/engine/callbacks.py +++ b/src/mednet/engine/callbacks.py @@ -362,13 +362,8 @@ class LoggingCallback(lightning.pytorch.Callback): out which dataset was used for this validation epoch. """ - if dataloader_idx == 0: - key = "loss/validation" - else: - key = f"loss/validation-{dataloader_idx}" - pl_module.log( - key, + "loss/validation", outputs.item(), prog_bar=False, on_step=False, diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py index 23df024b..d3a345b7 100644 --- a/src/mednet/engine/trainer.py +++ b/src/mednet/engine/trainer.py @@ -91,6 +91,10 @@ def run( main_pid=os.getpid(), ) + monitor_key = "loss/validation" + if len(datamodule.val_dataset_keys()) > 1: + monitor_key = "loss/validation/dataloader_idx_0" + # This checkpointer will operate at the end of every validation epoch # (which happens at each checkpoint period), it will then save the lowest # validation loss model observed. It will also save the last trained model @@ -98,7 +102,7 @@ def run( dirpath=output_folder, filename=CHECKPOINT_ALIASES["best"], save_last=True, # will (re)create the last trained model, at every iteration - monitor="loss/validation", + monitor=monitor_key, mode="min", save_on_train_epoch_end=True, every_n_epochs=validation_period, # frequency at which it checks the "monitor" diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py index 22b98baa..eada55b8 100644 --- a/src/mednet/models/alexnet.py +++ b/src/mednet/models/alexnet.py @@ -166,8 +166,7 @@ class Alexnet(Model): # data forwarding on the existing network outputs = self(images) - - return self._validation_loss(outputs, labels.float()) + return self._validation_loss[dataloader_idx](outputs, labels.float()) def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py index fcdb9f95..15da7f4e 100644 --- a/src/mednet/models/densenet.py +++ b/src/mednet/models/densenet.py @@ -164,7 +164,7 @@ class Densenet(Model): # data forwarding on the existing network outputs = self(images) - return self._validation_loss(outputs, labels.float()) + return self._validation_loss[dataloader_idx](outputs, labels.float()) def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py index a0b3701e..155ff19c 100644 --- a/src/mednet/models/model.py +++ b/src/mednet/models/model.py @@ -68,9 +68,10 @@ class Model(pl.LightningModule): self.model_transforms: TransformSequence = [] self._train_loss = train_loss - self._validation_loss = ( - validation_loss if validation_loss is not None else train_loss - ) + self._validation_loss = [ + (validation_loss if validation_loss is not None else train_loss) + ] + self._optimizer_type = optimizer_type self._optimizer_arguments = optimizer_arguments @@ -163,16 +164,32 @@ class Model(pl.LightningModule): setattr(self._train_loss, "pos_weight", train_weights) logger.info( - f"Balancing validation loss function {self._validation_loss}." + f"Balancing validation loss function {self._validation_loss[0]}." ) try: - getattr(self._validation_loss, "pos_weight") + getattr(self._validation_loss[0], "pos_weight") except AttributeError: logger.warning( "Validation loss does not posess a 'pos_weight' attribute and will not be balanced." ) else: - validation_weights = _get_label_weights( - datamodule.val_dataloader()["validation"] + # If multiple validation DataLoaders are used, each one will need to have a loss + # that is balanced for that DataLoader + + new_validation_losses = [] + loss_class = self._validation_loss[0].__class__ + + datamodule_validation_keys = datamodule.val_dataset_keys() + logger.info( + f"Found {len(datamodule_validation_keys)} keys in the validation datamodule. A balanced loss will be created for each key." ) - setattr(self._validation_loss, "pos_weight", validation_weights) + + for val_dataset_key in datamodule_validation_keys: + validation_weights = _get_label_weights( + datamodule.val_dataloader()[val_dataset_key] + ) + new_validation_losses.append( + loss_class(pos_weight=validation_weights) + ) + + self._validation_loss = new_validation_losses diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py index 389eac8c..54032eda 100644 --- a/src/mednet/models/pasa.py +++ b/src/mednet/models/pasa.py @@ -233,8 +233,7 @@ class Pasa(Model): # data forwarding on the existing network outputs = self(images) - - return self._validation_loss(outputs, labels.float()) + return self._validation_loss[dataloader_idx](outputs, labels.float()) def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) -- GitLab