Commit dcf143b2, authored by Daniel CARRON, committed by André Anjos

[transforms] Add transforms to crop and resize images

parent d7f27c9e
Merge request !46: Create common library
@@ -75,7 +75,6 @@ def train(
     setup_datamodule(
         datamodule,
         model,
-        batch_size,
         drop_incomplete_batch,
         cache_samples,
         parallel,
...
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-"""A transform that turns grayscale images to RGB."""

 import numpy
 import torch
@@ -9,6 +8,68 @@ import torch.nn
 import torchvision.transforms.functional


+def crop_image_to_mask(img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    """Crop the image to the boundaries of a boolean mask.
+
+    Parameters
+    ----------
+    img
+        The image to crop.
+    mask
+        The boolean mask to use for cropping.
+
+    Returns
+    -------
+        The cropped image.
+    """
+
+    if img.shape[-2:] != mask.shape[-2:]:
+        raise ValueError(
+            f"Image and mask must have the same size: {img.shape[-2:]} != {mask.shape[-2:]}"
+        )
+
+    h, w = img.shape[-2:]
+    flat_mask = mask.flatten()
+    top = flat_mask.nonzero()[0] // w
+    bottom = h - (torch.flip(flat_mask, dims=(0,)).nonzero()[0] // w)
+
+    flat_transposed_mask = torch.transpose(mask, 1, 2).flatten()
+    left = flat_transposed_mask.nonzero()[0] // h
+    right = w - (torch.flip(flat_transposed_mask, dims=(0,)).nonzero()[0] // h)
+
+    return img[:, top:bottom, left:right]
+
+
+def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor:
+    """Resize an image based on its longest side while keeping the aspect ratio.
+
+    Parameters
+    ----------
+    tensor
+        The tensor to resize.
+    max_side
+        The new length of the largest side.
+
+    Returns
+    -------
+        The resized image.
+    """
+
+    if max_side <= 0:
+        raise ValueError(f"The new max side ({max_side}) must be positive.")
+
+    height, width = tensor.shape[-2:]
+    aspect_ratio = float(height) / float(width)
+
+    if height >= width:
+        new_size = (max_side, int(max_side / aspect_ratio))
+    else:
+        new_size = (int(max_side * aspect_ratio), max_side)
+
+    return torchvision.transforms.Resize(new_size, antialias=True)(tensor)
+
+
 def square_center_pad(img: torch.Tensor) -> torch.Tensor:
     """Return a squared version of the image, centered on a canvas padded with
     zeros.
@@ -132,6 +193,23 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
     return torchvision.transforms.functional.rgb_to_grayscale(img)


+class ResizeMaxSide(torch.nn.Module):
+    """Resize an image on its longest side while keeping the aspect ratio.
+
+    Parameters
+    ----------
+    max_side
+        The new length of the largest side.
+    """
+
+    def __init__(self, max_side: int):
+        super().__init__()
+        self.max_side = max_side
+
+    def forward(self, img: torch.Tensor) -> torch.Tensor:
+        return resize_max_side(img, self.max_side)
+
+
 class SquareCenterPad(torch.nn.Module):
     """Transform to a squared version of the image, centered on a canvas padded
     with zeros.
...
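A minimal usage sketch of the two new helpers and the module-style wrapper follows; the import path matches the tests added later in this commit, while the tensor sizes and contents are illustrative only:

import torch

from mednet.libs.common.models.transforms import (
    ResizeMaxSide,
    crop_image_to_mask,
    resize_max_side,
)

# Fake 3-channel 50x100 image and a mask covering rows 10-29, columns 50-69.
image = torch.rand(3, 50, 100)
mask = torch.zeros(1, 50, 100)
mask[:, 10:30, 50:70] = 1

cropped = crop_image_to_mask(image, mask)  # -> shape (3, 20, 20), the mask bounding box
resized = resize_max_side(image, 120)      # -> shape (3, 60, 120), aspect ratio preserved
module_resized = ResizeMaxSide(120)(image) # same result via the torch.nn.Module wrapper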
@@ -275,7 +275,6 @@ def load_checkpoint(checkpoint_file, datamodule, model):
 def setup_datamodule(
     datamodule,
     model,
-    batch_size,
     drop_incomplete_batch,
     cache_samples,
     parallel,
...
@@ -5,8 +5,51 @@

 import numpy
 import PIL.Image
+import torch
 import torchvision.transforms.functional as F  # noqa: N812

 from mednet.libs.common.data.augmentations import ElasticDeformation
+from mednet.libs.common.models.transforms import (
+    crop_image_to_mask,
+    resize_max_side,
+)
+
+
+def test_crop_mask():
+    original_tensor_size = (3, 50, 100)
+    original_mask_size = (1, 50, 100)
+    slice_ = (slice(None), slice(10, 30), slice(50, 70))
+
+    tensor = torch.rand(original_tensor_size)
+    mask = torch.zeros(original_mask_size)
+    mask[slice_] = 1
+
+    cropped_tensor = crop_image_to_mask(tensor, mask)
+
+    assert cropped_tensor.shape == (3, 20, 20)
+    assert torch.all(cropped_tensor.eq(tensor[slice_]))
+
+
+def test_resize_max_size():
+    original_size = (3, 50, 100)
+    original_ratio = original_size[1] / original_size[2]
+
+    new_max_side = 120
+    tensor = torch.rand(original_size)
+
+    resized_tensor = resize_max_side(tensor, new_max_side)
+    resized_ratio = resized_tensor.shape[1] / resized_tensor.shape[2]
+    assert original_ratio == resized_ratio
+
+    transposed_tensor = tensor.transpose(1, 2)
+
+    resized_transposed_tensor = resize_max_side(transposed_tensor, new_max_side)
+    inv_ratio = 1 / (
+        resized_transposed_tensor.shape[1] / resized_transposed_tensor.shape[2]
+    )
+    assert original_ratio == inv_ratio
+
+    assert resized_tensor.shape[1] == resized_transposed_tensor.shape[2]
+    assert resized_tensor.shape[2] == resized_transposed_tensor.shape[1]
+
+
 def test_elastic_deformation(datadir):
...
@@ -11,6 +11,7 @@ import PIL.Image
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import JSONDatabaseSplit
 from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
@@ -50,25 +51,26 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             The sample representation.
         """

-        image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
-            mode="RGB"
-        )
-        tensor = tv_tensors.Image(to_tensor(image))
-
-        target = tv_tensors.Image(
-            to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-                    mode="1", dither=None
-                )
-            )
-        )
-
-        mask = tv_tensors.Mask(
-            to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
-                    mode="1", dither=None
-                )
-            )
-        )
+        image = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
+                mode="RGB"
+            )
+        )
+
+        target = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
+                mode="1", dither=None
+            )
+        )
+
+        mask = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
+                mode="1", dither=None
+            )
+        )
+
+        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))

         return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
...
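The same decode-then-crop pattern repeats in the two loaders below: image, target and mask are converted to tensors first, then all three are cropped with the same field-of-view mask so they stay pixel-aligned. A self-contained sketch of that idea, with hypothetical shapes and a box-shaped mask standing in for a real field of view:

import torch

from mednet.libs.common.models.transforms import crop_image_to_mask

image = torch.rand(3, 50, 100)    # RGB image
target = torch.zeros(1, 50, 100)  # binary ground truth
mask = torch.zeros(1, 50, 100)    # field-of-view mask
mask[:, 5:45, 20:80] = 1

# Cropping every plane with the same mask removes the empty border while
# keeping image, target and mask aligned pixel-for-pixel.
image_c = crop_image_to_mask(image, mask)    # (3, 40, 60)
target_c = crop_image_to_mask(target, mask)  # (1, 40, 60)
mask_c = crop_image_to_mask(mask, mask)      # (1, 40, 60)
assert image_c.shape[-2:] == target_c.shape[-2:] == mask_c.shape[-2:]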
@@ -11,6 +11,7 @@ import PIL.Image
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import JSONDatabaseSplit
 from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
@@ -51,25 +52,28 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             The sample representation.
         """

-        image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
-            mode="RGB"
-        )
-        tensor = tv_tensors.Image(to_tensor(image))
-
-        target = tv_tensors.Image(
-            to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-                    mode="1", dither=None
-                )
-            )
-        )
-
-        mask = tv_tensors.Mask(
-            to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
-                    mode="1", dither=None
-                )
-            )
-        )
+        image = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
+                mode="RGB"
+            )
+        )
+
+        target = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
+                mode="1", dither=None
+            )
+        )
+
+        mask = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
+                mode="1", dither=None
+            )
+        )
+
+        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))

         return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
...
@@ -12,6 +12,7 @@ import pkg_resources
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import JSONDatabaseSplit
 from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
@@ -53,25 +54,28 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             The sample representation.
         """

-        image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
-            mode="RGB"
-        )
-        tensor = tv_tensors.Image(to_tensor(image))
-
-        target = tv_tensors.Image(
-            to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-                    mode="1", dither=None
-                )
-            )
-        )
-
-        mask = tv_tensors.Mask(
-            to_tensor(
-                PIL.Image.open(Path(self._pkg_path) / str(sample[2])).convert(
-                    mode="1", dither=None
-                )
-            )
-        )
+        image = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
+                mode="RGB"
+            )
+        )
+
+        target = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
+                mode="1", dither=None
+            )
+        )
+
+        mask = to_tensor(
+            PIL.Image.open(Path(self._pkg_path) / str(sample[2])).convert(
+                mode="1", dither=None
+            )
+        )
+
+        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))

         return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
...
@@ -21,8 +21,8 @@ import torch
 import torch.nn
 from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.common.models.model import Model
+from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
 from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
-from torchvision.transforms.v2 import CenterCrop

 from .separate import separate
@@ -334,7 +334,12 @@ class LittleWNet(Model):
         self.name = "lwnet"
         self.num_classes = num_classes

-        self.model_transforms = [CenterCrop(size=(crop_size, crop_size))]
+        resize_transform = ResizeMaxSide(crop_size)
+
+        self.model_transforms = [
+            resize_transform,
+            SquareCenterPad(),
+        ]

         self.unet1 = LittleUNet(
             in_c=3,
@@ -360,16 +365,15 @@ class LittleWNet(Model):
     def training_step(self, batch, batch_idx):
         images = batch[0]
         ground_truths = batch[1]["target"]
-        masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3]
+        masks = batch[1]["mask"]

         outputs = self(self._augmentation_transforms(images))
         return self._train_loss(outputs, ground_truths, masks)

     def validation_step(self, batch, batch_idx):
         images = batch[0]
         ground_truths = batch[1]["target"]
-        masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3]
+        masks = batch[1]["mask"]

         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
...
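Replacing CenterCrop with ResizeMaxSide followed by SquareCenterPad still gives the model a fixed crop_size x crop_size input, but by scaling the longest side and zero-padding the shorter one instead of discarding image content. A quick sketch of the composed behaviour, using an arbitrary crop_size and a non-square input chosen only for illustration:

import torch

from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad

crop_size = 512  # illustrative; the real value comes from the model configuration
transforms = [ResizeMaxSide(crop_size), SquareCenterPad()]

img = torch.rand(3, 584, 565)  # non-square input
for t in transforms:
    img = t(img)

# Longest side scaled to crop_size, shorter side zero-padded symmetrically.
assert img.shape[-2:] == (crop_size, crop_size)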
@@ -61,7 +61,6 @@ def train(
     setup_datamodule(
         datamodule,
         model,
-        batch_size,
         drop_incomplete_batch,
         cache_samples,
         parallel,
...