From 1aafeb43b99465f72725e6bbadfc0bde4ed0ec43 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 5 Jun 2024 10:43:47 +0200
Subject: [PATCH] [transforms] Transform to crop multiple images at once

---
 src/mednet/libs/common/models/transforms.py | 21 ++++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/src/mednet/libs/common/models/transforms.py b/src/mednet/libs/common/models/transforms.py
index f926f620..5fa2d68f 100644
--- a/src/mednet/libs/common/models/transforms.py
+++ b/src/mednet/libs/common/models/transforms.py
@@ -14,7 +14,7 @@ def crop_image_to_mask(img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
     Parameters
     ----------
     img
-        The image to crop.
+        The image to crop, of shape channels x height x width.
     mask
         The boolean mask to use for cropping.
 
@@ -41,6 +41,25 @@ def crop_image_to_mask(img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
     return img[:, top:bottom, left:right]
 
 
+def crop_multiple_images_to_mask(
+    images: list[torch.Tensor], mask: torch.Tensor
+) -> list[torch.Tensor]:
+    """Apply crop_images_to_mask on multiple images.
+
+    Parameters
+    ----------
+    images
+        List of images to crop, of shape channels x height x width.
+    mask
+        The boolean mask to use for cropping.
+
+    Returns
+    -------
+        A list of cropped images.
+    """
+    return [crop_image_to_mask(img, mask) for img in images]
+
+
 def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor:
     """Resize image based on the longest side while keeping the aspect ratio.
 
-- 
GitLab