Commit dcf143b2 authored by Daniel CARRON, committed by André Anjos

[transforms] Add transforms to crop and resize images

parent d7f27c9e
1 merge request: !46 Create common library
@@ -75,7 +75,6 @@ def train(
    setup_datamodule(
        datamodule,
        model,
-        batch_size,
        drop_incomplete_batch,
        cache_samples,
        parallel,
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""A transform that turns grayscale images to RGB."""

import numpy
import torch

@@ -9,6 +8,68 @@ import torch.nn
import torchvision.transforms.functional
def crop_image_to_mask(img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Crop the image to the bounding box of a boolean mask.

    Parameters
    ----------
    img
        The image to crop.
    mask
        The boolean mask to use for cropping.

    Returns
    -------
        The cropped image.
    """
    if img.shape[-2:] != mask.shape[-2:]:
        raise ValueError(
            f"Image and mask must have the same size: {img.shape[-2:]} != {mask.shape[-2:]}"
        )

    h, w = img.shape[-2:]
    flat_mask = mask.flatten()
    # The first nonzero element of the flattened mask (row-major) gives the top
    # row; flipping the flattened mask locates the last nonzero element, which
    # gives the bottom row.
    top = flat_mask.nonzero()[0] // w
    bottom = h - (torch.flip(flat_mask, dims=(0,)).nonzero()[0] // w)

    # The same trick on the transposed mask yields the left and right columns.
    flat_transposed_mask = torch.transpose(mask, 1, 2).flatten()
    left = flat_transposed_mask.nonzero()[0] // h
    right = w - (torch.flip(flat_transposed_mask, dims=(0,)).nonzero()[0] // h)

    return img[:, top:bottom, left:right]
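
For intuition, the index arithmetic above can be checked on a toy mask (values illustrative, not part of the commit): with a 50×100 mask whose foreground spans rows 10-29 and columns 50-69, the first nonzero element of the flattened mask sits at index 10 * 100 + 50 = 1050, so top is 10; the flipped mask locates the last nonzero element, giving bottom = 30.

import torch

# Toy check of the bounding-box arithmetic used in crop_image_to_mask (illustration only).
mask = torch.zeros(1, 50, 100)
mask[:, 10:30, 50:70] = 1

flat = mask.flatten()
top = flat.nonzero()[0] // 100  # tensor([10])
bottom = 50 - (torch.flip(flat, dims=(0,)).nonzero()[0] // 100)  # tensor([30])
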
def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor:
    """Resize an image based on its longest side while keeping the aspect ratio.

    Parameters
    ----------
    tensor
        The tensor to resize.
    max_side
        The new length of the longest side.

    Returns
    -------
        The resized image.
    """
    if max_side <= 0:
        raise ValueError(f"The new max side ({max_side}) must be positive.")

    height, width = tensor.shape[-2:]
    aspect_ratio = float(height) / float(width)

    if height >= width:
        new_size = (max_side, int(max_side / aspect_ratio))
    else:
        new_size = (int(max_side * aspect_ratio), max_side)

    return torchvision.transforms.Resize(new_size, antialias=True)(tensor)
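
A quick illustration of the branch selection (shapes illustrative, not part of the commit): a landscape tensor keeps its width as the longest side, so the else branch fires and the height is scaled by the same factor.

import torch

tensor = torch.rand(3, 50, 100)  # height < width; aspect_ratio = 0.5
resized = resize_max_side(tensor, 120)  # else branch: new_size = (int(120 * 0.5), 120)
assert resized.shape == (3, 60, 120)
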
def square_center_pad(img: torch.Tensor) -> torch.Tensor:
    """Return a squared version of the image, centered on a canvas padded with
    zeros.

@@ -132,6 +193,23 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
    return torchvision.transforms.functional.rgb_to_grayscale(img)

class ResizeMaxSide(torch.nn.Module):
    """Resize an image along its longest side while keeping the aspect ratio.

    Parameters
    ----------
    max_side
        The new length of the longest side.
    """

    def __init__(self, max_side: int):
        super().__init__()
        self.max_side = max_side

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        return resize_max_side(img, self.max_side)

class SquareCenterPad(torch.nn.Module):
"""Transform to a squared version of the image, centered on a canvas padded
with zeros.
......
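
Taken together, the two new module transforms can normalize arbitrary rectangular inputs to a fixed square, which is how the LittleWNet change further below uses them. A minimal sketch, assuming SquareCenterPad pads the shorter side up to the longer one as its docstring suggests (the value 544 is illustrative, not from the commit):

import torch
import torchvision.transforms

# Illustrative composition of the classes defined above: any input becomes a square.
pipeline = torchvision.transforms.Compose([ResizeMaxSide(544), SquareCenterPad()])
out = pipeline(torch.rand(3, 50, 100))  # resized to (3, 272, 544), then zero-padded
assert out.shape == (3, 544, 544)
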
@@ -275,7 +275,6 @@ def load_checkpoint(checkpoint_file, datamodule, model):
def setup_datamodule(
    datamodule,
    model,
-    batch_size,
    drop_incomplete_batch,
    cache_samples,
    parallel,
......
@@ -5,8 +5,51 @@
import numpy
import PIL.Image
import torch
import torchvision.transforms.functional as F  # noqa: N812
from mednet.libs.common.data.augmentations import ElasticDeformation
from mednet.libs.common.models.transforms import (
    crop_image_to_mask,
    resize_max_side,
)

def test_crop_mask():
    original_tensor_size = (3, 50, 100)
    original_mask_size = (1, 50, 100)
    slice_ = (slice(None), slice(10, 30), slice(50, 70))

    tensor = torch.rand(original_tensor_size)
    mask = torch.zeros(original_mask_size)
    mask[slice_] = 1

    cropped_tensor = crop_image_to_mask(tensor, mask)

    assert cropped_tensor.shape == (3, 20, 20)
    assert torch.all(cropped_tensor.eq(tensor[slice_]))

def test_resize_max_side():
    original_size = (3, 50, 100)
    original_ratio = original_size[1] / original_size[2]

    new_max_side = 120
    tensor = torch.rand(original_size)

    resized_tensor = resize_max_side(tensor, new_max_side)
    resized_ratio = resized_tensor.shape[1] / resized_tensor.shape[2]
    assert original_ratio == resized_ratio

    transposed_tensor = tensor.transpose(1, 2)

    resized_transposed_tensor = resize_max_side(transposed_tensor, new_max_side)
    inv_ratio = 1 / (
        resized_transposed_tensor.shape[1] / resized_transposed_tensor.shape[2]
    )
    assert original_ratio == inv_ratio

    assert resized_tensor.shape[1] == resized_transposed_tensor.shape[2]
    assert resized_tensor.shape[2] == resized_transposed_tensor.shape[1]

def test_elastic_deformation(datadir):
......
@@ -11,6 +11,7 @@ import PIL.Image
from mednet.libs.common.data.datamodule import CachingDataModule
from mednet.libs.common.data.split import JSONDatabaseSplit
from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.models.transforms import crop_image_to_mask
from mednet.libs.segmentation.data.typing import (
    SegmentationRawDataLoader as _SegmentationRawDataLoader,
)
@@ -50,25 +51,26 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
        The sample representation.
        """

-        image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
-            mode="RGB"
-        )
-        tensor = tv_tensors.Image(to_tensor(image))
-        target = tv_tensors.Image(
-            to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-                    mode="1", dither=None
-                )
-            )
-        )
-        mask = tv_tensors.Mask(
-            to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
-                    mode="1", dither=None
-                )
-            )
-        )
+        image = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
+                mode="RGB"
+            )
+        )
+        target = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
+                mode="1", dither=None
+            )
+        )
+        mask = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
+                mode="1", dither=None
+            )
+        )
+        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))

        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
......
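
The loader change above (repeated for the two datasets below) decodes all three files to tensors first, then crops each with the same mask, so image, target, and mask stay pixel-aligned. A minimal sketch of that invariant, using made-up shapes and a rectangular foreground:

import torch

image = torch.rand(3, 50, 100)
target = torch.rand(1, 50, 100)
mask = torch.zeros(1, 50, 100)
mask[:, 10:30, 50:70] = 1

# All three crops share the mask's bounding box, preserving spatial alignment.
assert crop_image_to_mask(image, mask).shape == (3, 20, 20)
assert crop_image_to_mask(target, mask).shape == (1, 20, 20)
assert torch.all(crop_image_to_mask(mask, mask) == 1)  # holds for a rectangular foreground
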
@@ -11,6 +11,7 @@ import PIL.Image
from mednet.libs.common.data.datamodule import CachingDataModule
from mednet.libs.common.data.split import JSONDatabaseSplit
from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.models.transforms import crop_image_to_mask
from mednet.libs.segmentation.data.typing import (
    SegmentationRawDataLoader as _SegmentationRawDataLoader,
)
@@ -51,25 +52,28 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
        The sample representation.
        """

-        image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
-            mode="RGB"
-        )
-        tensor = tv_tensors.Image(to_tensor(image))
-        target = tv_tensors.Image(
-            to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-                    mode="1", dither=None
-                )
-            )
-        )
-        mask = tv_tensors.Mask(
-            to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
-                    mode="1", dither=None
-                )
-            )
-        )
+        image = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
+                mode="RGB"
+            )
+        )
+        target = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
+                mode="1", dither=None
+            )
+        )
+        mask = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
+                mode="1", dither=None
+            )
+        )
+        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))

        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
......
@@ -12,6 +12,7 @@ import pkg_resources
from mednet.libs.common.data.datamodule import CachingDataModule
from mednet.libs.common.data.split import JSONDatabaseSplit
from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.models.transforms import crop_image_to_mask
from mednet.libs.segmentation.data.typing import (
    SegmentationRawDataLoader as _SegmentationRawDataLoader,
)
@@ -53,25 +54,28 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
        The sample representation.
        """

-        image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
-            mode="RGB"
-        )
-        tensor = tv_tensors.Image(to_tensor(image))
-        target = tv_tensors.Image(
-            to_tensor(
-                PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-                    mode="1", dither=None
-                )
-            )
-        )
-        mask = tv_tensors.Mask(
-            to_tensor(
-                PIL.Image.open(Path(self._pkg_path) / str(sample[2])).convert(
-                    mode="1", dither=None
-                )
-            )
-        )
+        image = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
+                mode="RGB"
+            )
+        )
+        target = to_tensor(
+            PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
+                mode="1", dither=None
+            )
+        )
+        mask = to_tensor(
+            PIL.Image.open(Path(self._pkg_path) / str(sample[2])).convert(
+                mode="1", dither=None
+            )
+        )
+        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))

        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
......
@@ -21,8 +21,8 @@ import torch
import torch.nn
from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model
+from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
-from torchvision.transforms.v2 import CenterCrop

from .separate import separate
@@ -334,7 +334,12 @@ class LittleWNet(Model):
        self.name = "lwnet"
        self.num_classes = num_classes

-        self.model_transforms = [CenterCrop(size=(crop_size, crop_size))]
+        resize_transform = ResizeMaxSide(crop_size)
+
+        self.model_transforms = [
+            resize_transform,
+            SquareCenterPad(),
+        ]

        self.unet1 = LittleUNet(
            in_c=3,
@@ -360,16 +365,15 @@ class LittleWNet(Model):
    def training_step(self, batch, batch_idx):
        images = batch[0]
        ground_truths = batch[1]["target"]
-        masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3]
+        masks = batch[1]["mask"]

        outputs = self(self._augmentation_transforms(images))
        return self._train_loss(outputs, ground_truths, masks)

    def validation_step(self, batch, batch_idx):
        images = batch[0]
        ground_truths = batch[1]["target"]
-        masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3]
+        masks = batch[1]["mask"]

        outputs = self(images)
        return self._validation_loss(outputs, ground_truths, masks)
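
Both steps now take the mask from the sample dictionary produced by the updated loaders instead of inferring it from the batch length. A sketch of the assumed batch layout after collation (shapes, sizes, and names are illustrative, mirroring the loaders' return value):

import torch

# Hypothetical collated batch: (images, dict of annotations), as the loaders suggest.
batch = (
    torch.rand(4, 3, 544, 544),  # images
    dict(
        target=torch.rand(4, 1, 544, 544),  # ground truth
        mask=torch.ones(4, 1, 544, 544),  # evaluation masks
        name=["a.png", "b.png", "c.png", "d.png"],  # sample names
    ),
)
masks = batch[1]["mask"]  # what training_step and validation_step now read
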
......
@@ -61,7 +61,6 @@ def train(
    setup_datamodule(
        datamodule,
        model,
-        batch_size,
        drop_incomplete_batch,
        cache_samples,
        parallel,
......