diff --git a/src/mednet/libs/segmentation/config/models/lwnet.py b/src/mednet/libs/segmentation/config/models/lwnet.py index 09af9f842c731bcaa23e284eb292d90134664883..b0428f4498f6da65ac6549e39c1fec21a7bbf265 100644 --- a/src/mednet/libs/segmentation/config/models/lwnet.py +++ b/src/mednet/libs/segmentation/config/models/lwnet.py @@ -18,8 +18,7 @@ min_lr = 1e-08 # valley cycle = 50 # epochs for a complete scheduling cycle model = LittleWNet( - train_loss=MultiWeightedBCELogitsLoss(), - validation_loss=MultiWeightedBCELogitsLoss(), + loss_type=MultiWeightedBCELogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=max_lr), augmentation_transforms=[], diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py index d8e5b2cbd23119eb3d02cd5d2ac9e9342ac3ee88..ac23d617138e7f8f9d04849691b82832930ffd7b 100644 --- a/src/mednet/libs/segmentation/models/lwnet.py +++ b/src/mednet/libs/segmentation/models/lwnet.py @@ -17,11 +17,10 @@ Reference: [GALDRAN-2020]_ import typing -import lightning.pytorch as pl import torch import torch.nn -import torchvision.transforms from mednet.libs.common.data.typing import TransformSequence +from mednet.libs.common.models.model import Model from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss from torchvision.transforms.v2 import CenterCrop @@ -230,27 +229,20 @@ class LittleUNet(torch.nn.Module): return self.final(x) -class LittleWNet(pl.LightningModule): +class LittleWNet(Model): """Little W-Net model, concatenating two Little U-Net models. Parameters ---------- - train_loss - The loss to be used during the training. - - .. warning:: - - The loss should be set to always return batch averages (as opposed - to the batch sum), as our logging system expects it so. - validation_loss - The loss to be used for validation (may be different from the training - loss). If extra-validation sets are provided, the same loss will be - used throughout. + loss_type + The loss to be used for training and evaluation. .. warning:: The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. + loss_arguments + Arguments to the loss. optimizer_type The type of optimizer to use for training. optimizer_arguments @@ -266,32 +258,28 @@ class LittleWNet(pl.LightningModule): def __init__( self, - train_loss=MultiWeightedBCELogitsLoss(), - validation_loss=MultiWeightedBCELogitsLoss(), + loss_type: torch.nn.Module = MultiWeightedBCELogitsLoss, + loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], num_classes: int = 1, crop_size: int = 544, ): - super().__init__() + super().__init__( + loss_type, + loss_arguments, + optimizer_type, + optimizer_arguments, + augmentation_transforms, + num_classes, + ) self.name = "lwnet" self.num_classes = num_classes self.model_transforms = [CenterCrop(size=(crop_size, crop_size))] - self._train_loss = train_loss - self._validation_loss = ( - validation_loss if validation_loss is not None else train_loss - ) - self._optimizer_type = optimizer_type - self._optimizer_arguments = optimizer_arguments - - self._augmentation_transforms = torchvision.transforms.Compose( - augmentation_transforms - ) - self.unet1 = LittleUNet( in_c=3, n_classes=self.num_classes,