diff --git a/src/mednet/libs/segmentation/config/models/lwnet.py b/src/mednet/libs/segmentation/config/models/lwnet.py
index 9cb08ea3b463ed79c9b2c8c4dbc670eae7b9fa64..09af9f842c731bcaa23e284eb292d90134664883 100644
--- a/src/mednet/libs/segmentation/config/models/lwnet.py
+++ b/src/mednet/libs/segmentation/config/models/lwnet.py
@@ -23,4 +23,5 @@ model = LittleWNet(
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=max_lr),
     augmentation_transforms=[],
+    crop_size=544,
 )
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index 6cad67c925c5e9ff0a38527ba328b36242b62ab3..d8e5b2cbd23119eb3d02cd5d2ac9e9342ac3ee88 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -260,6 +260,8 @@ class LittleWNet(pl.LightningModule):
         applied on the input **before** it is fed into the network.
     num_classes
         Number of outputs (classes) for this model.
+    crop_size
+        Side length, in pixels, of the square center-crop applied to input images.
     """
 
     def __init__(
@@ -270,13 +272,14 @@ class LittleWNet(pl.LightningModule):
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
+        crop_size: int = 544,
     ):
         super().__init__()
 
         self.name = "lwnet"
         self.num_classes = num_classes
 
-        self.model_transforms = [CenterCrop(size=(544, 544))]
+        self.model_transforms = [CenterCrop(size=(crop_size, crop_size))]
 
         self._train_loss = train_loss
         self._validation_loss = (