diff --git a/src/mednet/libs/common/models/transforms.py b/src/mednet/libs/common/models/transforms.py
index f926f620d5ebb1b046c9f6adc0fe2130d08dc83a..5fa2d68fe8595504cde5062226df3053c7133ae0 100644
--- a/src/mednet/libs/common/models/transforms.py
+++ b/src/mednet/libs/common/models/transforms.py
@@ -14,7 +14,7 @@ def crop_image_to_mask(img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
     Parameters
     ----------
     img
-        The image to crop.
+        The image to crop, of shape channels x height x width.
     mask
         The boolean mask to use for cropping.
 
@@ -41,6 +41,25 @@ def crop_image_to_mask(img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
     return img[:, top:bottom, left:right]
 
 
+def crop_multiple_images_to_mask(
+    images: list[torch.Tensor], mask: torch.Tensor
+) -> list[torch.Tensor]:
+    """Apply crop_images_to_mask on multiple images.
+
+    Parameters
+    ----------
+    images
+        List of images to crop, of shape channels x height x width.
+    mask
+        The boolean mask to use for cropping.
+
+    Returns
+    -------
+        A list of cropped images.
+    """
+    return [crop_image_to_mask(img, mask) for img in images]
+
+
 def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor:
     """Resize image based on the longest side while keeping the aspect ratio.