diff --git a/src/mednet/config/models/alexnet.py b/src/mednet/config/models/alexnet.py index 9703f964a476b53d5aa076242bb0b02cedfe75ff..9a1cf54207a09ea75e862ac0548640241af8a705 100644 --- a/src/mednet/config/models/alexnet.py +++ b/src/mednet/config/models/alexnet.py @@ -15,8 +15,8 @@ from mednet.data.augmentations import ElasticDeformation from mednet.models.alexnet import Alexnet model = Alexnet( - train_loss=BCEWithLogitsLoss(), - validation_loss=BCEWithLogitsLoss(), + train_loss_type=BCEWithLogitsLoss, + validation_loss_type=BCEWithLogitsLoss, optimizer_type=SGD, optimizer_arguments=dict(lr=0.01, momentum=0.1), augmentation_transforms=[ElasticDeformation(p=0.8)], diff --git a/src/mednet/config/models/alexnet_pretrained.py b/src/mednet/config/models/alexnet_pretrained.py index 8887db8f6f006cd2580dabac44202055a5cdacab..ea9198aba28bd67c76ec814c3456c6cf367d4c03 100644 --- a/src/mednet/config/models/alexnet_pretrained.py +++ b/src/mednet/config/models/alexnet_pretrained.py @@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation from mednet.models.alexnet import Alexnet model = Alexnet( - train_loss=BCEWithLogitsLoss(), - validation_loss=BCEWithLogitsLoss(), + train_loss_type=BCEWithLogitsLoss, + validation_loss_type=BCEWithLogitsLoss, optimizer_type=SGD, optimizer_arguments=dict(lr=0.01, momentum=0.1), augmentation_transforms=[ElasticDeformation(p=0.8)], diff --git a/src/mednet/config/models/densenet.py b/src/mednet/config/models/densenet.py index f28dd23cd12c72e5fc6713e706f0e9c05158759c..7154bb740cb9db3c572f689b9fc6216e9862a041 100644 --- a/src/mednet/config/models/densenet.py +++ b/src/mednet/config/models/densenet.py @@ -15,8 +15,8 @@ from mednet.data.augmentations import ElasticDeformation from mednet.models.densenet import Densenet model = Densenet( - train_loss=BCEWithLogitsLoss(), - validation_loss=BCEWithLogitsLoss(), + train_loss_type=BCEWithLogitsLoss, + validation_loss_type=BCEWithLogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=0.0001), augmentation_transforms=[ElasticDeformation(p=0.2)], diff --git a/src/mednet/config/models/densenet_pretrained.py b/src/mednet/config/models/densenet_pretrained.py index 274a564601094a8ecb51e67c87f19f1f8197a30a..1025a68936a97aa63c5b84029efc9e75d16ae7ef 100644 --- a/src/mednet/config/models/densenet_pretrained.py +++ b/src/mednet/config/models/densenet_pretrained.py @@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation from mednet.models.densenet import Densenet model = Densenet( - train_loss=BCEWithLogitsLoss(), - validation_loss=BCEWithLogitsLoss(), + train_loss_type=BCEWithLogitsLoss, + validation_loss_type=BCEWithLogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=0.0001), augmentation_transforms=[ElasticDeformation(p=0.2)], diff --git a/src/mednet/config/models/densenet_rs.py b/src/mednet/config/models/densenet_rs.py index e7db48850d0e8d2b959b39ee93bae3b78dccfa80..23c1a06468c877350b873f064b8eb274c1ee6edc 100644 --- a/src/mednet/config/models/densenet_rs.py +++ b/src/mednet/config/models/densenet_rs.py @@ -16,8 +16,8 @@ from mednet.data.augmentations import ElasticDeformation from mednet.models.densenet import Densenet model = Densenet( - train_loss=BCEWithLogitsLoss(), - validation_loss=BCEWithLogitsLoss(), + train_loss_type=BCEWithLogitsLoss, + validation_loss_type=BCEWithLogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=0.0001), augmentation_transforms=[ElasticDeformation(p=0.2)], diff --git a/src/mednet/config/models/pasa.py b/src/mednet/config/models/pasa.py index 227b9b426568bebf327c1ac3f206a5fc3f2b44b6..16457d1a960ab0c1204fd272fc1fb95330fe4704 100644 --- a/src/mednet/config/models/pasa.py +++ b/src/mednet/config/models/pasa.py @@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation from mednet.models.pasa import Pasa model = Pasa( - train_loss=BCEWithLogitsLoss(), - validation_loss=BCEWithLogitsLoss(), + train_loss_type=BCEWithLogitsLoss, + validation_loss_type=BCEWithLogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=8e-5), augmentation_transforms=[ElasticDeformation(p=0.8)], diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py index d3a345b71d357c732075f482954e41b905591093..0993354f8dc2d84b30f0ab18f95751621a6076fb 100644 --- a/src/mednet/engine/trainer.py +++ b/src/mednet/engine/trainer.py @@ -72,6 +72,8 @@ def run( output_folder.mkdir(parents=True, exist_ok=True) + model.configure_losses() + from .loggers import CustomTensorboardLogger log_dir = "logs" diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py index eada55b8bbb6c49fb43e1fd72bc31fb05b3444f1..d4e2586d2c101f0b0b04a2d000447398d0f21a4c 100644 --- a/src/mednet/models/alexnet.py +++ b/src/mednet/models/alexnet.py @@ -27,14 +27,16 @@ class Alexnet(Model): Parameters ---------- - train_loss + train_loss_type The loss to be used during the training. .. warning:: The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. - validation_loss + train_loss_arguments + Arguments to the training loss. + validation_loss_type The loss to be used for validation (may be different from the training loss). If extra-validation sets are provided, the same loss will be used throughout. @@ -43,6 +45,8 @@ class Alexnet(Model): The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. + validation_loss_arguments + Arguments to the validation loss. optimizer_type The type of optimizer to use for training. optimizer_arguments @@ -59,8 +63,10 @@ class Alexnet(Model): def __init__( self, - train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(), - validation_loss: torch.nn.Module | None = None, + train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss, + train_loss_arguments: dict[str, typing.Any] = {}, + validation_loss_type: torch.nn.Module | None = None, + validation_loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], @@ -68,8 +74,10 @@ class Alexnet(Model): num_classes: int = 1, ): super().__init__( - train_loss, - validation_loss, + train_loss_type, + train_loss_arguments, + validation_loss_type, + validation_loss_arguments, optimizer_type, optimizer_arguments, augmentation_transforms, @@ -166,7 +174,7 @@ class Alexnet(Model): # data forwarding on the existing network outputs = self(images) - return self._validation_loss[dataloader_idx](outputs, labels.float()) + return self._validation_loss(outputs, labels.float()) def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py index 15da7f4eaded9618849c29a32cdaff97a9ddb9cb..c90ab10cefdb6094067b1dc00ce90c9102de4ba4 100644 --- a/src/mednet/models/densenet.py +++ b/src/mednet/models/densenet.py @@ -25,14 +25,16 @@ class Densenet(Model): Parameters ---------- - train_loss + train_loss_type The loss to be used during the training. .. warning:: The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. - validation_loss + train_loss_arguments + Arguments to the training loss. + validation_loss_type The loss to be used for validation (may be different from the training loss). If extra-validation sets are provided, the same loss will be used throughout. @@ -41,6 +43,8 @@ class Densenet(Model): The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. + validation_loss_arguments + Arguments to the validation loss. optimizer_type The type of optimizer to use for training. optimizer_arguments @@ -59,8 +63,10 @@ class Densenet(Model): def __init__( self, - train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(), - validation_loss: torch.nn.Module | None = None, + train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss, + train_loss_arguments: dict[str, typing.Any] = {}, + validation_loss_type: torch.nn.Module | None = None, + validation_loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], @@ -69,8 +75,10 @@ class Densenet(Model): num_classes: int = 1, ): super().__init__( - train_loss, - validation_loss, + train_loss_type, + train_loss_arguments, + validation_loss_type, + validation_loss_arguments, optimizer_type, optimizer_arguments, augmentation_transforms, @@ -164,7 +172,7 @@ class Densenet(Model): # data forwarding on the existing network outputs = self(images) - return self._validation_loss[dataloader_idx](outputs, labels.float()) + return self._validation_loss(outputs, labels.float()) def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py index e5a13dc8336aa28335bba58d26e98d0814ac9ed8..27e5c32d78ae866651c79e377fbfaa2663c4c35a 100644 --- a/src/mednet/models/model.py +++ b/src/mednet/models/model.py @@ -24,14 +24,16 @@ class Model(pl.LightningModule): Parameters ---------- - train_loss + train_loss_type The loss to be used during the training. .. warning:: The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. - validation_loss + train_loss_arguments + Arguments to the training loss. + validation_loss_type The loss to be used for validation (may be different from the training loss). If extra-validation sets are provided, the same loss will be used throughout. @@ -40,6 +42,8 @@ class Model(pl.LightningModule): The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. + validation_loss_arguments + Arguments to the validation loss. optimizer_type The type of optimizer to use for training. optimizer_arguments @@ -53,8 +57,10 @@ class Model(pl.LightningModule): def __init__( self, - train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(), - validation_loss: torch.nn.Module | None = None, + train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss, + train_loss_arguments: dict[str, typing.Any] = {}, + validation_loss_type: torch.nn.Module | None = None, + validation_loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], @@ -67,10 +73,13 @@ class Model(pl.LightningModule): self.model_transforms: TransformSequence = [] - self._train_loss = train_loss - self._validation_loss = [ - (validation_loss if validation_loss is not None else train_loss) - ] + self._train_loss_type = train_loss_type + self._train_loss_arguments = train_loss_arguments + self._train_loss = None + + self._validation_loss_type = validation_loss_type + self._validation_loss_arguments = validation_loss_arguments + self.validation_loss = None self._optimizer_type = optimizer_type self._optimizer_arguments = optimizer_arguments @@ -138,6 +147,12 @@ class Model(pl.LightningModule): def predict_step(self, batch, batch_idx, dataloader_idx=0): raise NotImplementedError + def configure_losses(self): + self._train_loss = self._train_loss_type(**self._train_loss_arguments) + self._validation_loss = self._validation_loss_type( + **self._validation_loss_arguments + ) + def configure_optimizers(self): return self._optimizer_type( self.parameters(), @@ -153,44 +168,30 @@ class Model(pl.LightningModule): Instance of a datamodule. """ - logger.info(f"Balancing training loss function {self._train_loss}.") + logger.info( + f"Balancing training loss function {self._train_loss_type}." + ) try: - getattr(self._train_loss, "pos_weight") + getattr(self._train_loss_type(), "pos_weight") except AttributeError: logger.warning( "Training loss does not posess a 'pos_weight' attribute and will not be balanced." ) else: train_weights = get_positive_weights(datamodule.train_dataloader()) - setattr(self._train_loss, "pos_weight", train_weights) + self._train_loss_arguments["pos_weight"] = train_weights logger.info( - f"Balancing validation loss function {self._validation_loss[0]}." + f"Balancing validation loss function {self._validation_loss_type}." ) try: - getattr(self._validation_loss[0], "pos_weight") + getattr(self._validation_loss_type(), "pos_weight") except AttributeError: logger.warning( "Validation loss does not posess a 'pos_weight' attribute and will not be balanced." ) else: - # If multiple validation DataLoaders are used, each one will need to have a loss - # that is balanced for that DataLoader - - new_validation_losses = [] - loss_class = self._validation_loss[0].__class__ - - datamodule_validation_keys = datamodule.val_dataset_keys() - logger.info( - f"Found {len(datamodule_validation_keys)} keys in the validation datamodule. A balanced loss will be created for each key." + validation_weights = get_positive_weights( + datamodule.val_dataloader()["validation"] ) - - for val_dataset_key in datamodule_validation_keys: - validation_weights = get_positive_weights( - datamodule.val_dataloader()[val_dataset_key] - ) - new_validation_losses.append( - loss_class(pos_weight=validation_weights) - ) - - self._validation_loss = new_validation_losses + self._validation_loss_arguments["pos_weight"] = validation_weights diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py index 54032edad5feee198b47274174b99e9a33fc8521..5e5cd743eae3f4abab99797137da723838c682eb 100644 --- a/src/mednet/models/pasa.py +++ b/src/mednet/models/pasa.py @@ -30,14 +30,16 @@ class Pasa(Model): Parameters ---------- - train_loss + train_loss_type The loss to be used during the training. .. warning:: The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. - validation_loss + train_loss_arguments + Arguments to the training loss. + validation_loss_type The loss to be used for validation (may be different from the training loss). If extra-validation sets are provided, the same loss will be used throughout. @@ -46,6 +48,8 @@ class Pasa(Model): The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. + validation_loss_arguments + Arguments to the validation loss. optimizer_type The type of optimizer to use for training. optimizer_arguments @@ -59,16 +63,20 @@ class Pasa(Model): def __init__( self, - train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(), - validation_loss: torch.nn.Module | None = None, + train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss, + train_loss_arguments: dict[str, typing.Any] = {}, + validation_loss_type: torch.nn.Module | None = None, + validation_loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], num_classes: int = 1, ): super().__init__( - train_loss, - validation_loss, + train_loss_type, + train_loss_arguments, + validation_loss_type, + validation_loss_arguments, optimizer_type, optimizer_arguments, augmentation_transforms, @@ -233,7 +241,7 @@ class Pasa(Model): # data forwarding on the existing network outputs = self(images) - return self._validation_loss[dataloader_idx](outputs, labels.float()) + return self._validation_loss(outputs, labels.float()) def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0])