From 8769c32708572dfd15f57e5c56abc9bf4f2b451a Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Thu, 2 May 2024 12:27:32 +0200
Subject: [PATCH] [model] Use a single type of loss for train and validation

---
 src/mednet/config/models/alexnet.py          |  3 +-
 .../config/models/alexnet_pretrained.py      |  3 +-
 src/mednet/config/models/densenet.py         |  3 +-
 .../config/models/densenet_pretrained.py     |  3 +-
 src/mednet/config/models/densenet_rs.py      |  3 +-
 src/mednet/config/models/pasa.py             |  3 +-
 src/mednet/models/alexnet.py                 | 31 +++-------
 src/mednet/models/densenet.py                | 31 +++-------
 src/mednet/models/model.py                   | 58 ++++++-------------
 src/mednet/models/pasa.py                    | 31 +++-------
 10 files changed, 47 insertions(+), 122 deletions(-)

diff --git a/src/mednet/config/models/alexnet.py b/src/mednet/config/models/alexnet.py
index 9a1cf542..7f281867 100644
--- a/src/mednet/config/models/alexnet.py
+++ b/src/mednet/config/models/alexnet.py
@@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/config/models/alexnet_pretrained.py b/src/mednet/config/models/alexnet_pretrained.py
index ea9198ab..a9356555 100644
--- a/src/mednet/config/models/alexnet_pretrained.py
+++ b/src/mednet/config/models/alexnet_pretrained.py
@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/config/models/densenet.py b/src/mednet/config/models/densenet.py
index 7154bb74..9ee510ac 100644
--- a/src/mednet/config/models/densenet.py
+++ b/src/mednet/config/models/densenet.py
@@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/densenet_pretrained.py b/src/mednet/config/models/densenet_pretrained.py
index 1025a689..b7e2efcd 100644
--- a/src/mednet/config/models/densenet_pretrained.py
+++ b/src/mednet/config/models/densenet_pretrained.py
@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/densenet_rs.py b/src/mednet/config/models/densenet_rs.py
index 23c1a064..813bb76c 100644
--- a/src/mednet/config/models/densenet_rs.py
+++ b/src/mednet/config/models/densenet_rs.py
@@ -16,8 +16,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/pasa.py b/src/mednet/config/models/pasa.py
index 16457d1a..7787d10e 100644
--- a/src/mednet/config/models/pasa.py
+++ b/src/mednet/config/models/pasa.py
@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.pasa import Pasa
 
 model = Pasa(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=8e-5),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py
index d4e2586d..75223c9a 100644
--- a/src/mednet/models/alexnet.py
+++ b/src/mednet/models/alexnet.py
@@ -27,26 +27,15 @@ class Alexnet(Model):
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as
           opposed to the batch sum), as our logging system expects it so.
 
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss). If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-            The loss should be set to always return batch averages (as opposed
-            to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -63,10 +52,8 @@ class Alexnet(Model):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -74,10 +61,8 @@ class Alexnet(Model):
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss_type,
-            train_loss_arguments,
-            validation_loss_type,
-            validation_loss_arguments,
+            loss_type,
+            loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py
index c90ab10c..76df1ed6 100644
--- a/src/mednet/models/densenet.py
+++ b/src/mednet/models/densenet.py
@@ -25,26 +25,15 @@ class Densenet(Model):
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as
           opposed to the batch sum), as our logging system expects it so.
 
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss). If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-            The loss should be set to always return batch averages (as opposed
-            to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -63,10 +52,8 @@ class Densenet(Model):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -75,10 +62,8 @@ class Densenet(Model):
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss_type,
-            train_loss_arguments,
-            validation_loss_type,
-            validation_loss_arguments,
+            loss_type,
+            loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index 27e5c32d..57c1b618 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -24,26 +24,15 @@ class Model(pl.LightningModule):
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as
           opposed to the batch sum), as our logging system expects it so.
 
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss). If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-            The loss should be set to always return batch averages (as opposed
-            to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -57,10 +46,8 @@ class Model(pl.LightningModule):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -73,13 +60,13 @@ class Model(pl.LightningModule):
 
         self.model_transforms: TransformSequence = []
 
-        self._train_loss_type = train_loss_type
-        self._train_loss_arguments = train_loss_arguments
+        self._loss_type = loss_type
 
+        self._train_loss = None
+        self._train_loss_arguments = loss_arguments
 
-        self._validation_loss_type = validation_loss_type
-        self._validation_loss_arguments = validation_loss_arguments
         self.validation_loss = None
+        self._validation_loss_arguments = loss_arguments
 
         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments
@@ -148,8 +135,8 @@ class Model(pl.LightningModule):
         raise NotImplementedError
 
     def configure_losses(self):
-        self._train_loss = self._train_loss_type(**self._train_loss_arguments)
-        self._validation_loss = self._validation_loss_type(
+        self._train_loss = self._loss_type(**self._train_loss_arguments)
+        self._validation_loss = self._loss_type(
             **self._validation_loss_arguments
         )
 
@@ -160,7 +147,7 @@ class Model(pl.LightningModule):
         )
 
     def balance_losses(self, datamodule) -> None:
-        """Balance the loss based on the distribution of targets in the datamodule, if the loss function supports it.
+        """Balance the loss based on the distribution of targets in the datamodule, if the loss supports it (contains a 'pos_weight' attribute).
 
         Parameters
         ----------
@@ -168,29 +155,18 @@ class Model(pl.LightningModule):
         datamodule
             Instance of a datamodule.
         """
-        logger.info(
-            f"Balancing training loss function {self._train_loss_type}."
-        )
         try:
-            getattr(self._train_loss_type(), "pos_weight")
+            getattr(self._loss_type(), "pos_weight")
         except AttributeError:
             logger.warning(
-                "Training loss does not posess a 'pos_weight' attribute and will not be balanced."
+                f"Loss {self._loss_type} does not possess a 'pos_weight' attribute and will not be balanced."
             )
         else:
+            logger.info(f"Balancing training loss {self._loss_type}.")
             train_weights = get_positive_weights(datamodule.train_dataloader())
             self._train_loss_arguments["pos_weight"] = train_weights
 
-        logger.info(
-            f"Balancing validation loss function {self._validation_loss_type}."
-        )
-        try:
-            getattr(self._validation_loss_type(), "pos_weight")
-        except AttributeError:
-            logger.warning(
-                "Validation loss does not posess a 'pos_weight' attribute and will not be balanced."
-            )
-        else:
+            logger.info(f"Balancing validation loss {self._loss_type}.")
             validation_weights = get_positive_weights(
                 datamodule.val_dataloader()["validation"]
             )
diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py
index 5e5cd743..e9147683 100644
--- a/src/mednet/models/pasa.py
+++ b/src/mednet/models/pasa.py
@@ -30,26 +30,15 @@ class Pasa(Model):
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as
           opposed to the batch sum), as our logging system expects it so.
 
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss). If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-            The loss should be set to always return batch averages (as opposed
-            to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -63,20 +52,16 @@ class Pasa(Model):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss_type,
-            train_loss_arguments,
-            validation_loss_type,
-            validation_loss_arguments,
+            loss_type,
+            loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
--
GitLab
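
After this patch, a configuration module passes a single loss type. The sketch
below mirrors src/mednet/config/models/pasa.py as changed above; the two torch
import lines are assumed (they sit above the visible hunk context and are not
shown in the diff), and loss_arguments is written out with its default only to
illustrate the new keyword, which the config files in the patch omit.

    # Sketch of a model configuration after this patch (based on the
    # pasa.py config hunk above; the torch imports are assumed).
    from torch.nn import BCEWithLogitsLoss
    from torch.optim import Adam

    from mednet.data.augmentations import ElasticDeformation
    from mednet.models.pasa import Pasa

    model = Pasa(
        loss_type=BCEWithLogitsLoss,  # single type, instantiated twice internally:
        loss_arguments={},            # once for training, once for validation
        optimizer_type=Adam,
        optimizer_arguments=dict(lr=8e-5),
        augmentation_transforms=[ElasticDeformation(p=0.8)],
    )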
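
The balance_losses() rework above keys on the loss type exposing a 'pos_weight'
attribute, as torch.nn.BCEWithLogitsLoss does. The toy example below illustrates
the kind of weighting get_positive_weights() feeds into the loss arguments; the
helper itself lives elsewhere in mednet and is not shown in this patch, so the
negatives/positives ratio is an assumption about its behaviour.

    # Standalone sketch of 'pos_weight' balancing with BCEWithLogitsLoss.
    import torch

    targets = torch.tensor([0.0, 0.0, 0.0, 1.0])  # 3 negatives, 1 positive
    # Up-weight the rare positive class by the negative/positive ratio.
    pos_weight = (targets == 0.0).sum() / (targets == 1.0).sum()  # tensor(3.)

    loss = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    logits = torch.tensor([-0.5, 0.2, 0.1, 1.3])
    print(loss(logits, targets))  # batch average, as the docstrings require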