From 1e2c50e0b0b07ddc6e6212581ccfb757756d29d5 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Thu, 25 Apr 2024 10:01:46 +0200 Subject: [PATCH] [lwnet] Add parameter for crop size --- src/mednet/libs/segmentation/config/models/lwnet.py | 1 + src/mednet/libs/segmentation/models/lwnet.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/mednet/libs/segmentation/config/models/lwnet.py b/src/mednet/libs/segmentation/config/models/lwnet.py index 9cb08ea3..09af9f84 100644 --- a/src/mednet/libs/segmentation/config/models/lwnet.py +++ b/src/mednet/libs/segmentation/config/models/lwnet.py @@ -23,4 +23,5 @@ model = LittleWNet( optimizer_type=Adam, optimizer_arguments=dict(lr=max_lr), augmentation_transforms=[], + crop_size=544, ) diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py index 6cad67c9..d8e5b2cb 100644 --- a/src/mednet/libs/segmentation/models/lwnet.py +++ b/src/mednet/libs/segmentation/models/lwnet.py @@ -260,6 +260,8 @@ class LittleWNet(pl.LightningModule): applied on the input **before** it is fed into the network. num_classes Number of outputs (classes) for this model. + crop_size + The size of the image after center cropping. """ def __init__( @@ -270,13 +272,14 @@ class LittleWNet(pl.LightningModule): optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], num_classes: int = 1, + crop_size: int = 544, ): super().__init__() self.name = "lwnet" self.num_classes = num_classes - self.model_transforms = [CenterCrop(size=(544, 544))] + self.model_transforms = [CenterCrop(size=(crop_size, crop_size))] self._train_loss = train_loss self._validation_loss = ( -- GitLab