diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py
index 8483f667eb71197a191345901e59be02ec5171cd..d1a1376964edc5fb0800b926955996d8429c7b9d 100644
--- a/src/mednet/libs/classification/scripts/train.py
+++ b/src/mednet/libs/classification/scripts/train.py
@@ -75,7 +75,6 @@ def train(
     setup_datamodule(
         datamodule,
         model,
-        batch_size,
         drop_incomplete_batch,
         cache_samples,
         parallel,
diff --git a/src/mednet/libs/common/models/transforms.py b/src/mednet/libs/common/models/transforms.py
index 132861cce9c71f3af9e6c1dc894b22bd63b805c7..f926f620d5ebb1b046c9f6adc0fe2130d08dc83a 100644
--- a/src/mednet/libs/common/models/transforms.py
+++ b/src/mednet/libs/common/models/transforms.py
@@ -1,7 +1,6 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-"""A transform that turns grayscale images to RGB."""
 
 import numpy
 import torch
@@ -9,6 +8,68 @@ import torch.nn
 import torchvision.transforms.functional
 
 
+def crop_image_to_mask(img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    """Square crop image to the boundaries of a boolean mask.
+
+    Parameters
+    ----------
+    img
+        The image to crop.
+    mask
+        The boolean mask to use for cropping.
+
+    Returns
+    -------
+        The cropped image.
+    """
+
+    if img.shape[-2:] != mask.shape[-2:]:
+        raise ValueError(
+            f"Image and mask must have the same size: {img.shape[-2:]} != {mask.shape[-2:]}"
+        )
+
+    h, w = img.shape[-2:]
+
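+    # The first non-zero entry of the flattened mask, integer-divided by the
+    # width, gives the top row; flipping the flattened mask and repeating the
+    # lookup gives the bottom row (exclusive).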
+    flat_mask = mask.flatten()
+    top = flat_mask.nonzero()[0] // w
+    bottom = h - (torch.flip(flat_mask, dims=(0,)).nonzero()[0] // w)
+
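+    # The same lookup on the transposed mask yields the left and right column
+    # bounds.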
+    flat_transposed_mask = torch.transpose(mask, 1, 2).flatten()
+    left = flat_transposed_mask.nonzero()[0] // h
+    right = w - (torch.flip(flat_transposed_mask, dims=(0,)).nonzero()[0] // h)
+
+    return img[:, top:bottom, left:right]
+
+
+def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor:
+    """Resize image based on the longest side while keeping the aspect ratio.
+
+    Parameters
+    ----------
+    tensor
+        The tensor to resize.
+    max_side
+        The new length of the longest side.
+
+    Returns
+    -------
+        The resized image.
+    """
+
+    if max_side <= 0:
+        raise ValueError(f"The new max side ({max_side}) must be positive.")
+
+    height, width = tensor.shape[-2:]
+    aspect_ratio = float(height) / float(width)
+
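+    # The longest side becomes max_side; the other side is scaled by the same
+    # factor to preserve the aspect ratio.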
+    if height >= width:
+        new_size = (max_side, int(max_side / aspect_ratio))
+    else:
+        new_size = (int(max_side * aspect_ratio), max_side)
+
+    return torchvision.transforms.Resize(new_size, antialias=True)(tensor)
+
+
 def square_center_pad(img: torch.Tensor) -> torch.Tensor:
     """Return a squared version of the image, centered on a canvas padded with
     zeros.
@@ -132,6 +193,23 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
     return torchvision.transforms.functional.rgb_to_grayscale(img)
 
 
+class ResizeMaxSide(torch.nn.Module):
+    """Resize image on the longest side while keeping the aspect ratio.
+
+    Parameters
+    ----------
+    max_side
+        The new length of the longest side.
+    """
+
+    def __init__(self, max_side: int):
+        super().__init__()
+        self.max_side = max_side
+
+    def forward(self, img: torch.Tensor) -> torch.Tensor:
+        return resize_max_side(img, self.max_side)
+
+
 class SquareCenterPad(torch.nn.Module):
     """Transform to a squared version of the image, centered on a canvas padded
     with zeros.
diff --git a/src/mednet/libs/common/scripts/train.py b/src/mednet/libs/common/scripts/train.py
index 774d485316e068af2192cb2d846a84f5612a6e88..b806609a84634559f6cec6a28cf6ef69b76b21dd 100644
--- a/src/mednet/libs/common/scripts/train.py
+++ b/src/mednet/libs/common/scripts/train.py
@@ -275,7 +275,6 @@ def load_checkpoint(checkpoint_file, datamodule, model):
 def setup_datamodule(
     datamodule,
     model,
-    batch_size,
     drop_incomplete_batch,
     cache_samples,
     parallel,
diff --git a/src/mednet/libs/common/tests/test_transforms.py b/src/mednet/libs/common/tests/test_transforms.py
index 6b80387669297305a401dec0ece7e67f5c2d0566..a8f1d71479c7c954c7459f6b85ddbb0d00e0bef5 100644
--- a/src/mednet/libs/common/tests/test_transforms.py
+++ b/src/mednet/libs/common/tests/test_transforms.py
@@ -5,8 +5,51 @@
 
 import numpy
 import PIL.Image
+import torch
 import torchvision.transforms.functional as F  # noqa: N812
 from mednet.libs.common.data.augmentations import ElasticDeformation
+from mednet.libs.common.models.transforms import (
+    crop_image_to_mask,
+    resize_max_side,
+)
+
+
+def test_crop_mask():
+    original_tensor_size = (3, 50, 100)
+    original_mask_size = (1, 50, 100)
+    slice_ = (slice(None), slice(10, 30), slice(50, 70))
+
+    tensor = torch.rand(original_tensor_size)
+    mask = torch.zeros(original_mask_size)
+    mask[slice_] = 1
+
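+    # The crop should be exactly the rectangle selected by slice_.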
+    cropped_tensor = crop_image_to_mask(tensor, mask)
+
+    assert cropped_tensor.shape == (3, 20, 20)
+    assert torch.all(cropped_tensor.eq(tensor[slice_]))
+
+
+def test_resize_max_side():
+    original_size = (3, 50, 100)
+    original_ratio = original_size[1] / original_size[2]
+
+    new_max_side = 120
+    tensor = torch.rand(original_size)
+
+    resized_tensor = resize_max_side(tensor, new_max_side)
+    resized_ratio = resized_tensor.shape[1] / resized_tensor.shape[2]
+    assert original_ratio == resized_ratio
+
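+    # Transposing swaps the longest side; the resized shapes should swap too.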
+    transposed_tensor = tensor.transpose(1, 2)
+
+    resized_transposed_tensor = resize_max_side(transposed_tensor, new_max_side)
+    inv_ratio = 1 / (
+        resized_transposed_tensor.shape[1] / resized_transposed_tensor.shape[2]
+    )
+    assert original_ratio == inv_ratio
+
+    assert resized_tensor.shape[1] == resized_transposed_tensor.shape[2]
+    assert resized_tensor.shape[2] == resized_transposed_tensor.shape[1]
 
 
 def test_elastic_deformation(datadir):
diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
index d7f2890ac96b8c8b4e80429410433927e706e6e0..b3a2a76730d5be54881d57dccf08e941dceff260 100644
--- a/src/mednet/libs/segmentation/config/data/drive/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
@@ -11,6 +11,7 @@ import PIL.Image
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import JSONDatabaseSplit
 from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
@@ -50,25 +51,26 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             The sample representation.
         """
 
-        image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
-            mode="RGB"
+        image = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
+                mode="RGB"
+            )
         )
-        tensor = tv_tensors.Image(to_tensor(image))
-        target = tv_tensors.Image(
-            to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-                    mode="1", dither=None
-                )
+        target = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
+                mode="1", dither=None
             )
         )
-        mask = tv_tensors.Mask(
-            to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
-                    mode="1", dither=None
-                )
+        mask = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
+                mode="1", dither=None
             )
         )
 
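+        # Crop image, target and mask to the bounding box of the mask so that
+        # all tensors stay aligned.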
+        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
+
         return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
 
 
diff --git a/src/mednet/libs/segmentation/config/data/iostar/datamodule.py b/src/mednet/libs/segmentation/config/data/iostar/datamodule.py
index 3c810f59fa6a61232b25607206b6fb12945d9f5b..72500523a933a0bd60e50d9f72ea296281799ca4 100644
--- a/src/mednet/libs/segmentation/config/data/iostar/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/iostar/datamodule.py
@@ -11,6 +11,7 @@ import PIL.Image
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import JSONDatabaseSplit
 from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
@@ -51,25 +52,28 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             The sample representation.
         """
 
-        image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
-            mode="RGB"
+        image = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
+                mode="RGB"
+            )
         )
-        tensor = tv_tensors.Image(to_tensor(image))
-        target = tv_tensors.Image(
-            to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-                    mode="1", dither=None
-                )
+
+        target = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
+                mode="1", dither=None
             )
         )
-        mask = tv_tensors.Mask(
-            to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
-                    mode="1", dither=None
-                )
+
+        mask = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
+                mode="1", dither=None
             )
         )
 
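+        # Crop all tensors to the bounding box of the mask so they remain
+        # aligned with each other.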
+        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
+
         return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
 
 
diff --git a/src/mednet/libs/segmentation/config/data/stare/datamodule.py b/src/mednet/libs/segmentation/config/data/stare/datamodule.py
index 5fc90042b6c34d06ae1e780e27c3cb30f16698b8..d037a6947169d7c3659f17f5064fe921f843f5ef 100644
--- a/src/mednet/libs/segmentation/config/data/stare/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/stare/datamodule.py
@@ -12,6 +12,7 @@ import pkg_resources
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import JSONDatabaseSplit
 from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
@@ -53,25 +54,28 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             The sample representation.
         """
 
-        image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
-            mode="RGB"
+        image = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
+                mode="RGB"
+            )
         )
-        tensor = tv_tensors.Image(to_tensor(image))
-        target = tv_tensors.Image(
-            to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-                    mode="1", dither=None
-                )
+
+        target = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
+                mode="1", dither=None
             )
         )
-        mask = tv_tensors.Mask(
-            to_tensor(
-                PIL.Image.open(Path(self._pkg_path) / str(sample[2])).convert(
-                    mode="1", dither=None
-                )
+
+        mask = to_tensor(
+            PIL.Image.open(Path(self._pkg_path) / str(sample[2])).convert(
+                mode="1", dither=None
             )
         )
 
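+        # Crop image, target and mask to the mask's bounding box so the
+        # tensors remain aligned.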
+        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
+
         return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
 
 
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index a5cfc9bd711c552dece3a5724b60290dff2603bc..8bae2956f333216f75a21e4109fdf63fc1a93851 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -21,8 +21,8 @@ import torch
 import torch.nn
 from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.common.models.model import Model
+from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
 from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
-from torchvision.transforms.v2 import CenterCrop
 
 from .separate import separate
 
@@ -334,7 +334,12 @@ class LittleWNet(Model):
         self.name = "lwnet"
         self.num_classes = num_classes
 
-        self.model_transforms = [CenterCrop(size=(crop_size, crop_size))]
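+        # Resize the longest side to crop_size, then pad to a square canvas so
+        # that no image content is discarded.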
+        resize_transform = ResizeMaxSide(crop_size)
+
+        self.model_transforms = [
+            resize_transform,
+            SquareCenterPad(),
+        ]
 
         self.unet1 = LittleUNet(
             in_c=3,
@@ -360,16 +365,15 @@ class LittleWNet(Model):
     def training_step(self, batch, batch_idx):
         images = batch[0]
         ground_truths = batch[1]["target"]
-        masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3]
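+        # The datamodule always provides a mask alongside the target.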
+        masks = batch[1]["mask"]
 
         outputs = self(self._augmentation_transforms(images))
-
         return self._train_loss(outputs, ground_truths, masks)
 
     def validation_step(self, batch, batch_idx):
         images = batch[0]
         ground_truths = batch[1]["target"]
-        masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3]
+        masks = batch[1]["mask"]
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py
index fc726a14cbf5f9efbdcc66e5b300785b5a676814..41467b896c63136d0b6905151645d63d5346c755 100644
--- a/src/mednet/libs/segmentation/scripts/train.py
+++ b/src/mednet/libs/segmentation/scripts/train.py
@@ -61,7 +61,6 @@ def train(
     setup_datamodule(
         datamodule,
         model,
-        batch_size,
         drop_incomplete_batch,
         cache_samples,
         parallel,