From ea69e43173bae83aebf614f7433c139df9ce6ce4 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 1 May 2024 14:26:24 +0200
Subject: [PATCH] [model] Model takes loss type and arguments during
 instantiation

---
 src/mednet/config/models/alexnet.py      |  4 +-
 .../config/models/alexnet_pretrained.py  |  4 +-
 src/mednet/config/models/densenet.py     |  4 +-
 .../config/models/densenet_pretrained.py |  4 +-
 src/mednet/config/models/densenet_rs.py  |  4 +-
 src/mednet/config/models/pasa.py         |  4 +-
 src/mednet/engine/trainer.py             |  2 +
 src/mednet/models/alexnet.py             | 22 +++++--
 src/mednet/models/densenet.py            | 22 +++++--
 src/mednet/models/model.py               | 65 ++++++++++---------
 src/mednet/models/pasa.py                | 22 +++++--
 11 files changed, 92 insertions(+), 65 deletions(-)

diff --git a/src/mednet/config/models/alexnet.py b/src/mednet/config/models/alexnet.py
index 9703f964..9a1cf542 100644
--- a/src/mednet/config/models/alexnet.py
+++ b/src/mednet/config/models/alexnet.py
@@ -15,8 +15,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/config/models/alexnet_pretrained.py b/src/mednet/config/models/alexnet_pretrained.py
index 8887db8f..ea9198ab 100644
--- a/src/mednet/config/models/alexnet_pretrained.py
+++ b/src/mednet/config/models/alexnet_pretrained.py
@@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/config/models/densenet.py b/src/mednet/config/models/densenet.py
index f28dd23c..7154bb74 100644
--- a/src/mednet/config/models/densenet.py
+++ b/src/mednet/config/models/densenet.py
@@ -15,8 +15,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/densenet_pretrained.py b/src/mednet/config/models/densenet_pretrained.py
index 274a5646..1025a689 100644
--- a/src/mednet/config/models/densenet_pretrained.py
+++ b/src/mednet/config/models/densenet_pretrained.py
@@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/densenet_rs.py b/src/mednet/config/models/densenet_rs.py
index e7db4885..23c1a064 100644
--- a/src/mednet/config/models/densenet_rs.py
+++ b/src/mednet/config/models/densenet_rs.py
@@ -16,8 +16,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/pasa.py b/src/mednet/config/models/pasa.py
index 227b9b42..16457d1a 100644
--- a/src/mednet/config/models/pasa.py
+++ b/src/mednet/config/models/pasa.py
@@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.pasa import Pasa
 
 model = Pasa(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=8e-5),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
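With this change, configuration modules pass the loss class itself plus a
separate dictionary of constructor arguments, instead of a ready-made loss
instance.  A minimal sketch of what a downstream configuration could now look
like; the pos_weight value is invented for illustration, everything else
mirrors the pasa configuration above:

    import torch
    from torch.nn import BCEWithLogitsLoss
    from torch.optim import Adam

    from mednet.models.pasa import Pasa

    model = Pasa(
        train_loss_type=BCEWithLogitsLoss,
        # forwarded verbatim to BCEWithLogitsLoss(...) when the loss is
        # finally built by model.configure_losses(); illustrative value only
        train_loss_arguments=dict(pos_weight=torch.tensor([2.0])),
        validation_loss_type=BCEWithLogitsLoss,
        optimizer_type=Adam,
        optimizer_arguments=dict(lr=8e-5),
    )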
diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py
index d3a345b7..0993354f 100644
--- a/src/mednet/engine/trainer.py
+++ b/src/mednet/engine/trainer.py
@@ -72,6 +72,8 @@ def run(
 
     output_folder.mkdir(parents=True, exist_ok=True)
 
+    model.configure_losses()
+
     from .loggers import CustomTensorboardLogger
 
     log_dir = "logs"
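run() now instantiates the loss objects just before fitting starts.  Deferring
construction is what lets balance_losses() (see src/mednet/models/model.py
below) adjust the loss arguments beforehand.  A compressed sketch of the
resulting order of operations; the train() wrapper and its balance flag are
hypothetical, only balance_losses(), configure_losses() and trainer.fit() come
from this code base:

    def train(model, datamodule, trainer, balance: bool = False):
        if balance:
            # records pos_weight inside the model's loss-argument
            # dictionaries, but does not build any loss object yet
            model.balance_losses(datamodule)
        # only now are the loss classes called with their argument dicts
        model.configure_losses()
        trainer.fit(model, datamodule)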
diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py
index eada55b8..d4e2586d 100644
--- a/src/mednet/models/alexnet.py
+++ b/src/mednet/models/alexnet.py
@@ -27,14 +27,16 @@ class Alexnet(Model):
     Parameters
     ----------
-    train_loss
+    train_loss_type
         The loss to be used during the training.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
 
-    validation_loss
+    train_loss_arguments
+        Arguments to the training loss.
+    validation_loss_type
         The loss to be used for validation (may be different from the training
         loss).  If extra-validation sets are provided, the same loss will be
         used throughout.
@@ -43,6 +45,8 @@ class Alexnet(Model):
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
 
+    validation_loss_arguments
+        Arguments to the validation loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -59,8 +63,10 @@ class Alexnet(Model):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        train_loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        train_loss_arguments: dict[str, typing.Any] = {},
+        validation_loss_type: type[torch.nn.Module] | None = None,
+        validation_loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -68,8 +74,10 @@ class Alexnet(Model):
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss,
-            validation_loss,
+            train_loss_type,
+            train_loss_arguments,
+            validation_loss_type,
+            validation_loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
@@ -166,7 +174,7 @@ class Alexnet(Model):
 
         # data forwarding on the existing network
         outputs = self(images)
-        return self._validation_loss[dataloader_idx](outputs, labels.float())
+        return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
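The last hunk above is a behaviour change: validation used to index a list of
per-DataLoader balanced losses with dataloader_idx, and now applies a single
shared validation loss to every validation set.  A sketch of the resulting
validation_step; the batch layout is assumed from context, since the hunk does
not show it:

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        images, labels = batch[0], batch[1]["label"]  # assumed batch layout
        outputs = self(images)
        # dataloader_idx no longer selects a per-dataset loss
        return self._validation_loss(outputs, labels.float())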
diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py
index 15da7f4e..c90ab10c 100644
--- a/src/mednet/models/densenet.py
+++ b/src/mednet/models/densenet.py
@@ -25,14 +25,16 @@ class Densenet(Model):
     Parameters
     ----------
-    train_loss
+    train_loss_type
         The loss to be used during the training.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
 
-    validation_loss
+    train_loss_arguments
+        Arguments to the training loss.
+    validation_loss_type
         The loss to be used for validation (may be different from the training
         loss).  If extra-validation sets are provided, the same loss will be
         used throughout.
@@ -41,6 +43,8 @@ class Densenet(Model):
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
 
+    validation_loss_arguments
+        Arguments to the validation loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -59,8 +63,10 @@ class Densenet(Model):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        train_loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        train_loss_arguments: dict[str, typing.Any] = {},
+        validation_loss_type: type[torch.nn.Module] | None = None,
+        validation_loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -69,8 +75,10 @@ class Densenet(Model):
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss,
-            validation_loss,
+            train_loss_type,
+            train_loss_arguments,
+            validation_loss_type,
+            validation_loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
@@ -164,7 +172,7 @@ class Densenet(Model):
 
         # data forwarding on the existing network
         outputs = self(images)
-        return self._validation_loss[dataloader_idx](outputs, labels.float())
+        return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index e5a13dc8..27e5c32d 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -24,14 +24,16 @@ class Model(pl.LightningModule):
     Parameters
     ----------
-    train_loss
+    train_loss_type
         The loss to be used during the training.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
 
-    validation_loss
+    train_loss_arguments
+        Arguments to the training loss.
+    validation_loss_type
         The loss to be used for validation (may be different from the training
         loss).  If extra-validation sets are provided, the same loss will be
         used throughout.
@@ -40,6 +42,8 @@ class Model(pl.LightningModule):
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
 
+    validation_loss_arguments
+        Arguments to the validation loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -53,8 +57,10 @@ class Model(pl.LightningModule):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        train_loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        train_loss_arguments: dict[str, typing.Any] = {},
+        validation_loss_type: type[torch.nn.Module] | None = None,
+        validation_loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -67,10 +73,13 @@ class Model(pl.LightningModule):
 
         self.model_transforms: TransformSequence = []
 
-        self._train_loss = train_loss
-        self._validation_loss = [
-            (validation_loss if validation_loss is not None else train_loss)
-        ]
+        self._train_loss_type = train_loss_type
+        self._train_loss_arguments = dict(train_loss_arguments)  # copied: may be mutated
+        self._train_loss = None
+
+        self._validation_loss_type = validation_loss_type
+        self._validation_loss_arguments = dict(validation_loss_arguments)
+        self._validation_loss = None
 
         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments
@@ -138,6 +147,12 @@ class Model(pl.LightningModule):
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         raise NotImplementedError
 
+    def configure_losses(self):
+        self._train_loss = self._train_loss_type(**self._train_loss_arguments)
+        # fall back to the training loss type when none is set for validation
+        loss_type = self._validation_loss_type or self._train_loss_type
+        self._validation_loss = loss_type(**self._validation_loss_arguments)
+
     def configure_optimizers(self):
         return self._optimizer_type(
             self.parameters(),
@@ -153,44 +168,30 @@ class Model(pl.LightningModule):
 
             Instance of a datamodule.
         """
-        logger.info(f"Balancing training loss function {self._train_loss}.")
+        logger.info(
+            f"Balancing training loss function {self._train_loss_type}."
+        )
         try:
-            getattr(self._train_loss, "pos_weight")
+            getattr(self._train_loss_type(), "pos_weight")
         except AttributeError:
             logger.warning(
                 "Training loss does not posess a 'pos_weight' attribute and will not be balanced."
             )
         else:
             train_weights = get_positive_weights(datamodule.train_dataloader())
-            setattr(self._train_loss, "pos_weight", train_weights)
+            self._train_loss_arguments["pos_weight"] = train_weights
 
         logger.info(
-            f"Balancing validation loss function {self._validation_loss[0]}."
+            f"Balancing validation loss function {self._validation_loss_type}."
         )
         try:
-            getattr(self._validation_loss[0], "pos_weight")
+            getattr(self._validation_loss_type(), "pos_weight")
         except AttributeError:
             logger.warning(
                 "Validation loss does not posess a 'pos_weight' attribute and will not be balanced."
             )
         else:
-            # If multiple validation DataLoaders are used, each one will need to have a loss
-            # that is balanced for that DataLoader
-
-            new_validation_losses = []
-            loss_class = self._validation_loss[0].__class__
-
-            datamodule_validation_keys = datamodule.val_dataset_keys()
-            logger.info(
-                f"Found {len(datamodule_validation_keys)} keys in the validation datamodule. A balanced loss will be created for each key."
+            validation_weights = get_positive_weights(
+                datamodule.val_dataloader()["validation"]
             )
-
-            for val_dataset_key in datamodule_validation_keys:
-                validation_weights = get_positive_weights(
-                    datamodule.val_dataloader()[val_dataset_key]
-                )
-                new_validation_losses.append(
-                    loss_class(pos_weight=validation_weights)
-                )
-
-            self._validation_loss = new_validation_losses
+            self._validation_loss_arguments["pos_weight"] = validation_weights
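balance_losses() no longer mutates loss instances through setattr(); it
records pos_weight in the argument dictionaries, and configure_losses() later
builds losses that are balanced from construction.  A small illustration of
that hand-off, assuming a BCEWithLogitsLoss-based model and the single
"validation" split used in the hunk above (get_positive_weights is the helper
this module already uses; its import is not part of this patch):

    import torch

    weights = get_positive_weights(datamodule.train_dataloader())

    model.balance_losses(datamodule)  # pos_weight lands in the argument dicts
    model.configure_losses()          # builds BCEWithLogitsLoss(pos_weight=...)

    assert torch.equal(model._train_loss.pos_weight, weights)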
diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py
index 54032eda..5e5cd743 100644
--- a/src/mednet/models/pasa.py
+++ b/src/mednet/models/pasa.py
@@ -30,14 +30,16 @@ class Pasa(Model):
     Parameters
     ----------
-    train_loss
+    train_loss_type
         The loss to be used during the training.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
 
-    validation_loss
+    train_loss_arguments
+        Arguments to the training loss.
+    validation_loss_type
         The loss to be used for validation (may be different from the training
         loss).  If extra-validation sets are provided, the same loss will be
         used throughout.
@@ -46,6 +48,8 @@ class Pasa(Model):
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
 
+    validation_loss_arguments
+        Arguments to the validation loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -59,16 +63,20 @@ class Pasa(Model):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        train_loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        train_loss_arguments: dict[str, typing.Any] = {},
+        validation_loss_type: type[torch.nn.Module] | None = None,
+        validation_loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss,
-            validation_loss,
+            train_loss_type,
+            train_loss_arguments,
+            validation_loss_type,
+            validation_loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
@@ -233,7 +241,7 @@ class Pasa(Model):
 
         # data forwarding on the existing network
         outputs = self(images)
-        return self._validation_loss[dataloader_idx](outputs, labels.float())
+        return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
-- 
GitLab