diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py index 2de54fdf96d95ac4f553c43b548d168dcf2bace4..6c7d759f4a8890f576af327450ac89e144f6fbfa 100644 --- a/src/mednet/data/datamodule.py +++ b/src/mednet/data/datamodule.py @@ -758,7 +758,7 @@ class ConcatDataModule(lightning.LightningDataModule): else: self._datasets[name] = _ConcatDataset(datasets) - def _val_dataset_keys(self) -> list[str]: + def val_dataset_keys(self) -> list[str]: """Return list of validation dataset names. Returns @@ -796,11 +796,11 @@ class ConcatDataModule(lightning.LightningDataModule): """ if stage == "fit": - for k in ["train"] + self._val_dataset_keys(): + for k in ["train"] + self.val_dataset_keys(): self._setup_dataset(k) elif stage == "validate": - for k in self._val_dataset_keys(): + for k in self.val_dataset_keys(): self._setup_dataset(k) elif stage == "test": @@ -889,7 +889,7 @@ class ConcatDataModule(lightning.LightningDataModule): self._datasets[k], **validation_loader_opts, ) - for k in self._val_dataset_keys() + for k in self.val_dataset_keys() } def test_dataloader(self) -> dict[str, DataLoader]: diff --git a/src/mednet/engine/callbacks.py b/src/mednet/engine/callbacks.py index 4c19ac55e5d325c6506dff393f70f0503e0d5682..501f2c254b2d511993adadab2735f12d44fe73c1 100644 --- a/src/mednet/engine/callbacks.py +++ b/src/mednet/engine/callbacks.py @@ -362,13 +362,8 @@ class LoggingCallback(lightning.pytorch.Callback): out which dataset was used for this validation epoch. """ - if dataloader_idx == 0: - key = "loss/validation" - else: - key = f"loss/validation-{dataloader_idx}" - pl_module.log( - key, + "loss/validation", outputs.item(), prog_bar=False, on_step=False, diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py index 23df024bcf5efbc73c2c2b6bc207df3d5889a77e..d3a345b71d357c732075f482954e41b905591093 100644 --- a/src/mednet/engine/trainer.py +++ b/src/mednet/engine/trainer.py @@ -91,6 +91,10 @@ def run( main_pid=os.getpid(), ) + monitor_key = "loss/validation" + if len(datamodule.val_dataset_keys()) > 1: + monitor_key = "loss/validation/dataloader_idx_0" + # This checkpointer will operate at the end of every validation epoch # (which happens at each checkpoint period), it will then save the lowest # validation loss model observed. It will also save the last trained model @@ -98,7 +102,7 @@ def run( dirpath=output_folder, filename=CHECKPOINT_ALIASES["best"], save_last=True, # will (re)create the last trained model, at every iteration - monitor="loss/validation", + monitor=monitor_key, mode="min", save_on_train_epoch_end=True, every_n_epochs=validation_period, # frequency at which it checks the "monitor" diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py index 22b98baa42d0738704e0bc619a72c38d59208903..eada55b8bbb6c49fb43e1fd72bc31fb05b3444f1 100644 --- a/src/mednet/models/alexnet.py +++ b/src/mednet/models/alexnet.py @@ -166,8 +166,7 @@ class Alexnet(Model): # data forwarding on the existing network outputs = self(images) - - return self._validation_loss(outputs, labels.float()) + return self._validation_loss[dataloader_idx](outputs, labels.float()) def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py index fcdb9f95de1d856bfb3a5dd109f471e8202c4e66..15da7f4eaded9618849c29a32cdaff97a9ddb9cb 100644 --- a/src/mednet/models/densenet.py +++ b/src/mednet/models/densenet.py @@ -164,7 +164,7 @@ class Densenet(Model): # data forwarding on the existing network outputs = self(images) - return self._validation_loss(outputs, labels.float()) + return self._validation_loss[dataloader_idx](outputs, labels.float()) def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py index a0b3701eae4376481a3875e5c53a91ed6976fe05..155ff19c0c0887b9a0504b414eb93242b7b0ff63 100644 --- a/src/mednet/models/model.py +++ b/src/mednet/models/model.py @@ -68,9 +68,10 @@ class Model(pl.LightningModule): self.model_transforms: TransformSequence = [] self._train_loss = train_loss - self._validation_loss = ( - validation_loss if validation_loss is not None else train_loss - ) + self._validation_loss = [ + (validation_loss if validation_loss is not None else train_loss) + ] + self._optimizer_type = optimizer_type self._optimizer_arguments = optimizer_arguments @@ -163,16 +164,32 @@ class Model(pl.LightningModule): setattr(self._train_loss, "pos_weight", train_weights) logger.info( - f"Balancing validation loss function {self._validation_loss}." + f"Balancing validation loss function {self._validation_loss[0]}." ) try: - getattr(self._validation_loss, "pos_weight") + getattr(self._validation_loss[0], "pos_weight") except AttributeError: logger.warning( "Validation loss does not posess a 'pos_weight' attribute and will not be balanced." ) else: - validation_weights = _get_label_weights( - datamodule.val_dataloader()["validation"] + # If multiple validation DataLoaders are used, each one will need to have a loss + # that is balanced for that DataLoader + + new_validation_losses = [] + loss_class = self._validation_loss[0].__class__ + + datamodule_validation_keys = datamodule.val_dataset_keys() + logger.info( + f"Found {len(datamodule_validation_keys)} keys in the validation datamodule. A balanced loss will be created for each key." ) - setattr(self._validation_loss, "pos_weight", validation_weights) + + for val_dataset_key in datamodule_validation_keys: + validation_weights = _get_label_weights( + datamodule.val_dataloader()[val_dataset_key] + ) + new_validation_losses.append( + loss_class(pos_weight=validation_weights) + ) + + self._validation_loss = new_validation_losses diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py index 389eac8cacd8c4839012cbd1e5c6d0f3d0b14570..54032edad5feee198b47274174b99e9a33fc8521 100644 --- a/src/mednet/models/pasa.py +++ b/src/mednet/models/pasa.py @@ -233,8 +233,7 @@ class Pasa(Model): # data forwarding on the existing network outputs = self(images) - - return self._validation_loss(outputs, labels.float()) + return self._validation_loss[dataloader_idx](outputs, labels.float()) def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0])