From 1e2c50e0b0b07ddc6e6212581ccfb757756d29d5 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Thu, 25 Apr 2024 10:01:46 +0200
Subject: [PATCH] [lwnet] Add parameter for crop size

---
 src/mednet/libs/segmentation/config/models/lwnet.py | 1 +
 src/mednet/libs/segmentation/models/lwnet.py        | 5 ++++-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/mednet/libs/segmentation/config/models/lwnet.py b/src/mednet/libs/segmentation/config/models/lwnet.py
index 9cb08ea3..09af9f84 100644
--- a/src/mednet/libs/segmentation/config/models/lwnet.py
+++ b/src/mednet/libs/segmentation/config/models/lwnet.py
@@ -23,4 +23,5 @@ model = LittleWNet(
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=max_lr),
     augmentation_transforms=[],
+    crop_size=544,
 )
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index 6cad67c9..d8e5b2cb 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -260,6 +260,8 @@ class LittleWNet(pl.LightningModule):
         applied on the input **before** it is fed into the network.
     num_classes
         Number of outputs (classes) for this model.
+    crop_size
+        The size of the image after center cropping.
     """
 
     def __init__(
@@ -270,13 +272,14 @@ class LittleWNet(pl.LightningModule):
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
+        crop_size: int = 544,
     ):
         super().__init__()
 
         self.name = "lwnet"
         self.num_classes = num_classes
 
-        self.model_transforms = [CenterCrop(size=(544, 544))]
+        self.model_transforms = [CenterCrop(size=(crop_size, crop_size))]
 
         self._train_loss = train_loss
         self._validation_loss = (
-- 
GitLab