Skip to content
Snippets Groups Projects
Commit 59aefc94 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[lwnet] Common loss for train and validation

parent 24c7fb68
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -18,8 +18,7 @@ min_lr = 1e-08 # valley ...@@ -18,8 +18,7 @@ min_lr = 1e-08 # valley
cycle = 50 # epochs for a complete scheduling cycle cycle = 50 # epochs for a complete scheduling cycle
model = LittleWNet( model = LittleWNet(
train_loss=MultiWeightedBCELogitsLoss(), loss_type=MultiWeightedBCELogitsLoss,
validation_loss=MultiWeightedBCELogitsLoss(),
optimizer_type=Adam, optimizer_type=Adam,
optimizer_arguments=dict(lr=max_lr), optimizer_arguments=dict(lr=max_lr),
augmentation_transforms=[], augmentation_transforms=[],
......
...@@ -17,11 +17,10 @@ Reference: [GALDRAN-2020]_ ...@@ -17,11 +17,10 @@ Reference: [GALDRAN-2020]_
import typing import typing
import lightning.pytorch as pl
import torch import torch
import torch.nn import torch.nn
import torchvision.transforms
from mednet.libs.common.data.typing import TransformSequence from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model
from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
from torchvision.transforms.v2 import CenterCrop from torchvision.transforms.v2 import CenterCrop
...@@ -230,27 +229,20 @@ class LittleUNet(torch.nn.Module): ...@@ -230,27 +229,20 @@ class LittleUNet(torch.nn.Module):
return self.final(x) return self.final(x)
class LittleWNet(pl.LightningModule): class LittleWNet(Model):
"""Little W-Net model, concatenating two Little U-Net models. """Little W-Net model, concatenating two Little U-Net models.
Parameters Parameters
---------- ----------
train_loss loss_type
The loss to be used during the training. The loss to be used for training and evaluation.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
validation_loss
The loss to be used for validation (may be different from the training
loss). If extra-validation sets are provided, the same loss will be
used throughout.
.. warning:: .. warning::
The loss should be set to always return batch averages (as opposed The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so. to the batch sum), as our logging system expects it so.
loss_arguments
Arguments to the loss.
optimizer_type optimizer_type
The type of optimizer to use for training. The type of optimizer to use for training.
optimizer_arguments optimizer_arguments
...@@ -266,32 +258,28 @@ class LittleWNet(pl.LightningModule): ...@@ -266,32 +258,28 @@ class LittleWNet(pl.LightningModule):
def __init__( def __init__(
self, self,
train_loss=MultiWeightedBCELogitsLoss(), loss_type: torch.nn.Module = MultiWeightedBCELogitsLoss,
validation_loss=MultiWeightedBCELogitsLoss(), loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {}, optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [],
num_classes: int = 1, num_classes: int = 1,
crop_size: int = 544, crop_size: int = 544,
): ):
super().__init__() super().__init__(
loss_type,
loss_arguments,
optimizer_type,
optimizer_arguments,
augmentation_transforms,
num_classes,
)
self.name = "lwnet" self.name = "lwnet"
self.num_classes = num_classes self.num_classes = num_classes
self.model_transforms = [CenterCrop(size=(crop_size, crop_size))] self.model_transforms = [CenterCrop(size=(crop_size, crop_size))]
self._train_loss = train_loss
self._validation_loss = (
validation_loss if validation_loss is not None else train_loss
)
self._optimizer_type = optimizer_type
self._optimizer_arguments = optimizer_arguments
self._augmentation_transforms = torchvision.transforms.Compose(
augmentation_transforms
)
self.unet1 = LittleUNet( self.unet1 = LittleUNet(
in_c=3, in_c=3,
n_classes=self.num_classes, n_classes=self.num_classes,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment