diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py
index 8483f667eb71197a191345901e59be02ec5171cd..d1a1376964edc5fb0800b926955996d8429c7b9d 100644
--- a/src/mednet/libs/classification/scripts/train.py
+++ b/src/mednet/libs/classification/scripts/train.py
@@ -75,7 +75,6 @@ def train(
     setup_datamodule(
         datamodule,
         model,
-        batch_size,
         drop_incomplete_batch,
         cache_samples,
         parallel,
diff --git a/src/mednet/libs/common/models/transforms.py b/src/mednet/libs/common/models/transforms.py
index 132861cce9c71f3af9e6c1dc894b22bd63b805c7..f926f620d5ebb1b046c9f6adc0fe2130d08dc83a 100644
--- a/src/mednet/libs/common/models/transforms.py
+++ b/src/mednet/libs/common/models/transforms.py
@@ -1,7 +1,6 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-"""A transform that turns grayscale images to RGB."""
 
 import numpy
 import torch
@@ -9,6 +8,68 @@ import torch.nn
 import torchvision.transforms.functional
 
 
+def crop_image_to_mask(img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    """Crop an image to the bounding box of a boolean mask.
+
+    Parameters
+    ----------
+    img
+        The image to crop.
+    mask
+        The boolean mask to use for cropping.
+
+    Returns
+    -------
+    The cropped image.
+    """
+
+    if img.shape[-2:] != mask.shape[-2:]:
+        raise ValueError(
+            f"Image and mask must have the same size: {img.shape[-2:]} != {mask.shape[-2:]}"
+        )
+
+    h, w = img.shape[-2:]
+
+    flat_mask = mask.flatten()
+    top = flat_mask.nonzero()[0] // w
+    bottom = h - (torch.flip(flat_mask, dims=(0,)).nonzero()[0] // w)
+
+    flat_transposed_mask = torch.transpose(mask, 1, 2).flatten()
+    left = flat_transposed_mask.nonzero()[0] // h
+    right = w - (torch.flip(flat_transposed_mask, dims=(0,)).nonzero()[0] // h)
+
+    return img[:, top:bottom, left:right]
+
+
+def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor:
+    """Resize an image based on its longest side while keeping the aspect ratio.
+
+    Parameters
+    ----------
+    tensor
+        The tensor to resize.
+    max_side
+        The new length of the largest side.
+
+    Returns
+    -------
+    The resized image.
+    """
+
+    if max_side <= 0:
+        raise ValueError(f"The new max side ({max_side}) must be positive.")
+
+    height, width = tensor.shape[-2:]
+    aspect_ratio = float(height) / float(width)
+
+    if height >= width:
+        new_size = (max_side, int(max_side / aspect_ratio))
+    else:
+        new_size = (int(max_side * aspect_ratio), max_side)
+
+    return torchvision.transforms.Resize(new_size, antialias=True)(tensor)
+
+
 def square_center_pad(img: torch.Tensor) -> torch.Tensor:
     """Return a squared version of the image, centered on a canvas
     padded with zeros.
@@ -132,6 +193,23 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
     return torchvision.transforms.functional.rgb_to_grayscale(img)
 
 
+class ResizeMaxSide(torch.nn.Module):
+    """Resize image on the longest side while keeping the aspect ratio.
+
+    Parameters
+    ----------
+    max_side
+        The new length of the largest side.
+    """
+
+    def __init__(self, max_side: int):
+        super().__init__()
+        self.max_side = max_side
+
+    def forward(self, img: torch.Tensor) -> torch.Tensor:
+        return resize_max_side(img, self.max_side)
+
+
 class SquareCenterPad(torch.nn.Module):
     """Transform to a squared version of the image, centered on a canvas
     padded with zeros.
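Reviewer note (not part of the patch): the sketch below illustrates how the helpers added to transforms.py above are expected to compose on a C x H x W tensor. The import path and class names come from this diff; the tensor shapes, mask region and target size are invented for illustration only.

    import torch

    from mednet.libs.common.models.transforms import (
        ResizeMaxSide,
        SquareCenterPad,
        crop_image_to_mask,
    )

    # Invented example: a 3 x 512 x 768 image whose mask covers a 400 x 300 region.
    image = torch.rand(3, 512, 768)
    mask = torch.zeros(1, 512, 768)
    mask[:, 56:456, 234:534] = 1

    cropped = crop_image_to_mask(image, mask)  # 3 x 400 x 300: bounding box of the mask
    resized = ResizeMaxSide(256)(cropped)      # 3 x 256 x 192: longest side scaled to 256, ratio kept
    padded = SquareCenterPad()(resized)        # expected 3 x 256 x 256: zero-padded square canvas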
diff --git a/src/mednet/libs/common/scripts/train.py b/src/mednet/libs/common/scripts/train.py
index 774d485316e068af2192cb2d846a84f5612a6e88..b806609a84634559f6cec6a28cf6ef69b76b21dd 100644
--- a/src/mednet/libs/common/scripts/train.py
+++ b/src/mednet/libs/common/scripts/train.py
@@ -275,7 +275,6 @@ def load_checkpoint(checkpoint_file, datamodule, model):
 def setup_datamodule(
     datamodule,
     model,
-    batch_size,
     drop_incomplete_batch,
     cache_samples,
     parallel,
diff --git a/src/mednet/libs/common/tests/test_transforms.py b/src/mednet/libs/common/tests/test_transforms.py
index 6b80387669297305a401dec0ece7e67f5c2d0566..a8f1d71479c7c954c7459f6b85ddbb0d00e0bef5 100644
--- a/src/mednet/libs/common/tests/test_transforms.py
+++ b/src/mednet/libs/common/tests/test_transforms.py
@@ -5,8 +5,51 @@
 
 import numpy
 import PIL.Image
+import torch
 import torchvision.transforms.functional as F  # noqa: N812
 from mednet.libs.common.data.augmentations import ElasticDeformation
+from mednet.libs.common.models.transforms import (
+    crop_image_to_mask,
+    resize_max_side,
+)
+
+
+def test_crop_mask():
+    original_tensor_size = (3, 50, 100)
+    original_mask_size = (1, 50, 100)
+    slice_ = (slice(None), slice(10, 30), slice(50, 70))
+
+    tensor = torch.rand(original_tensor_size)
+    mask = torch.zeros(original_mask_size)
+    mask[slice_] = 1
+
+    cropped_tensor = crop_image_to_mask(tensor, mask)
+
+    assert cropped_tensor.shape == (3, 20, 20)
+    assert torch.all(cropped_tensor.eq(tensor[slice_]))
+
+
+def test_resize_max_side():
+    original_size = (3, 50, 100)
+    original_ratio = original_size[1] / original_size[2]
+
+    new_max_side = 120
+    tensor = torch.rand(original_size)
+
+    resized_tensor = resize_max_side(tensor, new_max_side)
+    resized_ratio = resized_tensor.shape[1] / resized_tensor.shape[2]
+    assert original_ratio == resized_ratio
+
+    transposed_tensor = tensor.transpose(1, 2)
+
+    resized_transposed_tensor = resize_max_side(transposed_tensor, new_max_side)
+    inv_ratio = 1 / (
+        resized_transposed_tensor.shape[1] / resized_transposed_tensor.shape[2]
+    )
+    assert original_ratio == inv_ratio
+
+    assert resized_tensor.shape[1] == resized_transposed_tensor.shape[2]
+    assert resized_tensor.shape[2] == resized_transposed_tensor.shape[1]
 
 
 def test_elastic_deformation(datadir):
diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
index d7f2890ac96b8c8b4e80429410433927e706e6e0..b3a2a76730d5be54881d57dccf08e941dceff260 100644
--- a/src/mednet/libs/segmentation/config/data/drive/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
@@ -11,6 +11,7 @@ import PIL.Image
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import JSONDatabaseSplit
 from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
@@ -50,25 +51,26 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             The sample representation.
""" - image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert( - mode="RGB" + image = to_tensor( + PIL.Image.open(Path(self.datadir) / str(sample[0])).convert( + mode="RGB" + ) ) - tensor = tv_tensors.Image(to_tensor(image)) - target = tv_tensors.Image( - to_tensor( - PIL.Image.open(Path(self.datadir) / str(sample[1])).convert( - mode="1", dither=None - ) + target = to_tensor( + PIL.Image.open(Path(self.datadir) / str(sample[1])).convert( + mode="1", dither=None ) ) - mask = tv_tensors.Mask( - to_tensor( - PIL.Image.open(Path(self.datadir) / str(sample[2])).convert( - mode="1", dither=None - ) + mask = to_tensor( + PIL.Image.open(Path(self.datadir) / str(sample[2])).convert( + mode="1", dither=None ) ) + tensor = tv_tensors.Image(crop_image_to_mask(image, mask)) + target = tv_tensors.Image(crop_image_to_mask(target, mask)) + mask = tv_tensors.Mask(crop_image_to_mask(mask, mask)) + return tensor, dict(target=target, mask=mask, name=sample[0]) # type: ignore[arg-type] diff --git a/src/mednet/libs/segmentation/config/data/iostar/datamodule.py b/src/mednet/libs/segmentation/config/data/iostar/datamodule.py index 3c810f59fa6a61232b25607206b6fb12945d9f5b..72500523a933a0bd60e50d9f72ea296281799ca4 100644 --- a/src/mednet/libs/segmentation/config/data/iostar/datamodule.py +++ b/src/mednet/libs/segmentation/config/data/iostar/datamodule.py @@ -11,6 +11,7 @@ import PIL.Image from mednet.libs.common.data.datamodule import CachingDataModule from mednet.libs.common.data.split import JSONDatabaseSplit from mednet.libs.common.data.typing import DatabaseSplit, Sample +from mednet.libs.common.models.transforms import crop_image_to_mask from mednet.libs.segmentation.data.typing import ( SegmentationRawDataLoader as _SegmentationRawDataLoader, ) @@ -51,25 +52,28 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader): The sample representation. 
""" - image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert( - mode="RGB" + image = to_tensor( + PIL.Image.open(Path(self.datadir) / str(sample[0])).convert( + mode="RGB" + ) ) - tensor = tv_tensors.Image(to_tensor(image)) - target = tv_tensors.Image( - to_tensor( - PIL.Image.open(Path(self.datadir) / str(sample[1])).convert( - mode="1", dither=None - ) + + target = to_tensor( + PIL.Image.open(Path(self.datadir) / str(sample[1])).convert( + mode="1", dither=None ) ) - mask = tv_tensors.Mask( - to_tensor( - PIL.Image.open(Path(self.datadir) / str(sample[2])).convert( - mode="1", dither=None - ) + + mask = to_tensor( + PIL.Image.open(Path(self.datadir) / str(sample[2])).convert( + mode="1", dither=None ) ) + tensor = tv_tensors.Image(crop_image_to_mask(image, mask)) + target = tv_tensors.Image(crop_image_to_mask(target, mask)) + mask = tv_tensors.Mask(crop_image_to_mask(mask, mask)) + return tensor, dict(target=target, mask=mask, name=sample[0]) # type: ignore[arg-type] diff --git a/src/mednet/libs/segmentation/config/data/stare/datamodule.py b/src/mednet/libs/segmentation/config/data/stare/datamodule.py index 5fc90042b6c34d06ae1e780e27c3cb30f16698b8..d037a6947169d7c3659f17f5064fe921f843f5ef 100644 --- a/src/mednet/libs/segmentation/config/data/stare/datamodule.py +++ b/src/mednet/libs/segmentation/config/data/stare/datamodule.py @@ -12,6 +12,7 @@ import pkg_resources from mednet.libs.common.data.datamodule import CachingDataModule from mednet.libs.common.data.split import JSONDatabaseSplit from mednet.libs.common.data.typing import DatabaseSplit, Sample +from mednet.libs.common.models.transforms import crop_image_to_mask from mednet.libs.segmentation.data.typing import ( SegmentationRawDataLoader as _SegmentationRawDataLoader, ) @@ -53,25 +54,28 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader): The sample representation. 
""" - image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert( - mode="RGB" + image = to_tensor( + PIL.Image.open(Path(self.datadir) / str(sample[0])).convert( + mode="RGB" + ) ) - tensor = tv_tensors.Image(to_tensor(image)) - target = tv_tensors.Image( - to_tensor( - PIL.Image.open(Path(self.datadir) / str(sample[1])).convert( - mode="1", dither=None - ) + + target = to_tensor( + PIL.Image.open(Path(self.datadir) / str(sample[1])).convert( + mode="1", dither=None ) ) - mask = tv_tensors.Mask( - to_tensor( - PIL.Image.open(Path(self._pkg_path) / str(sample[2])).convert( - mode="1", dither=None - ) + + mask = to_tensor( + PIL.Image.open(Path(self._pkg_path) / str(sample[2])).convert( + mode="1", dither=None ) ) + tensor = tv_tensors.Image(crop_image_to_mask(image, mask)) + target = tv_tensors.Image(crop_image_to_mask(target, mask)) + mask = tv_tensors.Mask(crop_image_to_mask(mask, mask)) + return tensor, dict(target=target, mask=mask, name=sample[0]) # type: ignore[arg-type] diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py index a5cfc9bd711c552dece3a5724b60290dff2603bc..8bae2956f333216f75a21e4109fdf63fc1a93851 100644 --- a/src/mednet/libs/segmentation/models/lwnet.py +++ b/src/mednet/libs/segmentation/models/lwnet.py @@ -21,8 +21,8 @@ import torch import torch.nn from mednet.libs.common.data.typing import TransformSequence from mednet.libs.common.models.model import Model +from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss -from torchvision.transforms.v2 import CenterCrop from .separate import separate @@ -334,7 +334,12 @@ class LittleWNet(Model): self.name = "lwnet" self.num_classes = num_classes - self.model_transforms = [CenterCrop(size=(crop_size, crop_size))] + resize_transform = ResizeMaxSide(crop_size) + + self.model_transforms = [ + resize_transform, + SquareCenterPad(), + ] self.unet1 = LittleUNet( in_c=3, @@ -360,16 +365,15 @@ class LittleWNet(Model): def training_step(self, batch, batch_idx): images = batch[0] ground_truths = batch[1]["target"] - masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3] + masks = batch[1]["mask"] outputs = self(self._augmentation_transforms(images)) - return self._train_loss(outputs, ground_truths, masks) def validation_step(self, batch, batch_idx): images = batch[0] ground_truths = batch[1]["target"] - masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3] + masks = batch[1]["mask"] outputs = self(images) return self._validation_loss(outputs, ground_truths, masks) diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py index fc726a14cbf5f9efbdcc66e5b300785b5a676814..41467b896c63136d0b6905151645d63d5346c755 100644 --- a/src/mednet/libs/segmentation/scripts/train.py +++ b/src/mednet/libs/segmentation/scripts/train.py @@ -61,7 +61,6 @@ def train( setup_datamodule( datamodule, model, - batch_size, drop_incomplete_batch, cache_samples, parallel,