diff --git a/src/mednet/libs/common/models/transforms.py b/src/mednet/libs/common/models/transforms.py index f926f620d5ebb1b046c9f6adc0fe2130d08dc83a..5fa2d68fe8595504cde5062226df3053c7133ae0 100644 --- a/src/mednet/libs/common/models/transforms.py +++ b/src/mednet/libs/common/models/transforms.py @@ -14,7 +14,7 @@ def crop_image_to_mask(img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: Parameters ---------- img - The image to crop. + The image to crop, of shape channels x height x width. mask The boolean mask to use for cropping. @@ -41,6 +41,25 @@ def crop_image_to_mask(img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: return img[:, top:bottom, left:right] +def crop_multiple_images_to_mask( + images: list[torch.Tensor], mask: torch.Tensor +) -> list[torch.Tensor]: + """Apply crop_images_to_mask on multiple images. + + Parameters + ---------- + images + List of images to crop, of shape channels x height x width. + mask + The boolean mask to use for cropping. + + Returns + ------- + A list of cropped images. + """ + return [crop_image_to_mask(img, mask) for img in images] + + def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor: """Resize image based on the longest side while keeping the aspect ratio.