Skip to content
Snippets Groups Projects
Commit 59aefc94 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[lwnet] Common loss for train and validation

parent 24c7fb68
No related branches found
No related tags found
1 merge request!46Create common library
......@@ -18,8 +18,7 @@ min_lr = 1e-08 # valley
cycle = 50 # epochs for a complete scheduling cycle
model = LittleWNet(
train_loss=MultiWeightedBCELogitsLoss(),
validation_loss=MultiWeightedBCELogitsLoss(),
loss_type=MultiWeightedBCELogitsLoss,
optimizer_type=Adam,
optimizer_arguments=dict(lr=max_lr),
augmentation_transforms=[],
......
......@@ -17,11 +17,10 @@ Reference: [GALDRAN-2020]_
import typing
import lightning.pytorch as pl
import torch
import torch.nn
import torchvision.transforms
from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model
from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
from torchvision.transforms.v2 import CenterCrop
......@@ -230,27 +229,20 @@ class LittleUNet(torch.nn.Module):
return self.final(x)
class LittleWNet(pl.LightningModule):
class LittleWNet(Model):
"""Little W-Net model, concatenating two Little U-Net models.
Parameters
----------
train_loss
The loss to be used during the training.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
validation_loss
The loss to be used for validation (may be different from the training
loss). If extra-validation sets are provided, the same loss will be
used throughout.
loss_type
The loss to be used for training and evaluation.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
loss_arguments
Arguments to the loss.
optimizer_type
The type of optimizer to use for training.
optimizer_arguments
......@@ -266,32 +258,28 @@ class LittleWNet(pl.LightningModule):
def __init__(
self,
train_loss=MultiWeightedBCELogitsLoss(),
validation_loss=MultiWeightedBCELogitsLoss(),
loss_type: torch.nn.Module = MultiWeightedBCELogitsLoss,
loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
num_classes: int = 1,
crop_size: int = 544,
):
super().__init__()
super().__init__(
loss_type,
loss_arguments,
optimizer_type,
optimizer_arguments,
augmentation_transforms,
num_classes,
)
self.name = "lwnet"
self.num_classes = num_classes
self.model_transforms = [CenterCrop(size=(crop_size, crop_size))]
self._train_loss = train_loss
self._validation_loss = (
validation_loss if validation_loss is not None else train_loss
)
self._optimizer_type = optimizer_type
self._optimizer_arguments = optimizer_arguments
self._augmentation_transforms = torchvision.transforms.Compose(
augmentation_transforms
)
self.unet1 = LittleUNet(
in_c=3,
n_classes=self.num_classes,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment