Commit 3172ede6 authored by André Anjos


[libs.common.models.transforms] Maximise use of torchvision transforms; Closes #86 as well as concluding the previous commit
parent 821f5c6f
1 merge request: !46 Create common library
Pipeline #89346 failed
@@ -21,7 +21,7 @@ model = mednet.libs.classification.models.pasa.Pasa(
     optimizer_type=torch.optim.Adam,
     optimizer_arguments=dict(lr=8e-5),
     model_transforms=[
-        mednet.libs.common.models.transforms.Grayscale(),
+        torchvision.transforms.Grayscale(),
         mednet.libs.common.models.transforms.SquareCenterPad(),
         torchvision.transforms.Resize(512, antialias=True),
     ],
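For context, a minimal, self-contained sketch of what the torchvision portion of the transform list above does to a tensor; the dummy input shape and the use of torchvision.transforms.Compose are illustrative assumptions, not code taken from mednet:

import torch
import torchvision.transforms

# Illustration only: chain the torchvision transforms from the snippet above
# and apply them to a dummy RGB tensor (channels x height x width).
pipeline = torchvision.transforms.Compose(
    [
        torchvision.transforms.Grayscale(),  # 3 color channels -> 1
        torchvision.transforms.Resize(512, antialias=True),  # smaller edge -> 512
    ]
)
image = torch.rand(3, 600, 400)
print(pipeline(image).shape)  # torch.Size([1, 768, 512]); aspect ratio preserved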
@@ -2,10 +2,14 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import typing
+
 import numpy
 import torch
 import torch.nn
 import torchvision.transforms.functional
+import torchvision.transforms.v2
+import torchvision.tv_tensors
 
 
 def crop_image_to_mask(img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
@@ -41,25 +45,6 @@ def crop_image_to_mask(img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
     return img[:, top:bottom, left:right]
 
 
-def crop_multiple_images_to_mask(
-    images: list[torch.Tensor], mask: torch.Tensor
-) -> list[torch.Tensor]:
-    """Apply crop_images_to_mask on multiple images.
-
-    Parameters
-    ----------
-    images
-        List of images to crop, of shape channels x height x width.
-    mask
-        The boolean mask to use for cropping.
-
-    Returns
-    -------
-    A list of cropped images.
-    """
-    return [crop_image_to_mask(img, mask) for img in images]
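A hedged usage sketch of the removed helper and the surviving crop_image_to_mask(): both crop to the bounding box of a boolean mask. The import path and the mask construction below are assumptions for illustration only:

import torch

# Assumed import path, derived from the [libs.common.models.transforms] tag.
from mednet.libs.common.models.transforms import crop_image_to_mask

# Hypothetical boolean mask that is True inside a central rectangle.
mask = torch.zeros(256, 256, dtype=torch.bool)
mask[64:192, 32:224] = True

images = [torch.rand(3, 256, 256), torch.rand(1, 256, 256)]

# Equivalent to the removed crop_multiple_images_to_mask(): apply the
# single-image crop to every tensor in the list, reusing the same mask.
cropped = [crop_image_to_mask(img, mask) for img in images]
# Each cropped tensor is now restricted to (roughly) the mask's bounding box.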
 def square_center_pad(img: torch.Tensor) -> torch.Tensor:
     """Return a squared version of the image, centered on a canvas padded with
     zeros.
@@ -93,96 +78,6 @@ def square_center_pad(img: torch.Tensor) -> torch.Tensor:
     )
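Since the body of square_center_pad() is collapsed in this diff, here is a minimal sketch of the idea its docstring describes (zero-pad the shorter side so the output is square, keeping the content centered); this is an illustrative re-implementation, not mednet's actual code:

import torch
import torch.nn.functional

def square_center_pad_sketch(img: torch.Tensor) -> torch.Tensor:
    # img is channels x height x width; pad the shorter spatial side with
    # zeros, split as evenly as possible, so that height == width.
    height, width = img.shape[-2], img.shape[-1]
    size = max(height, width)
    pad_left = (size - width) // 2
    pad_right = size - width - pad_left
    pad_top = (size - height) // 2
    pad_bottom = size - height - pad_top
    # torch.nn.functional.pad takes (left, right, top, bottom) for the last
    # two dimensions.
    return torch.nn.functional.pad(
        img, (pad_left, pad_right, pad_top, pad_bottom), value=0
    )

print(square_center_pad_sketch(torch.rand(3, 600, 400)).shape)  # [3, 600, 600]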
-def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor:
-    """Convert an image in grayscale to RGB.
-
-    If the image is already in RGB format, then this is a NOOP - the same
-    tensor is returned (no cloning). If the image is in grayscale format
-    (number of color channels = 1), then replicate it to obtain 3 color channels
-    (a new copy is returned in this case).
-
-    Parameters
-    ----------
-    img
-        The tensor to be transformed. Expected to be in the form: ``[...,
-        [1,3], H, W]`` (i.e. arbitrary number of leading dimensions).
-
-    Returns
-    -------
-    torch.Tensor
-        Transformed tensor with 3 identical color channels.
-    """
-
-    if img.ndim < 3:
-        raise TypeError(
-            f"Input image tensor should have at least 3 dimensions, "
-            f"but found {img.ndim}. If a grayscale image was provided, "
-            f"ensure to include a channel dimension of size 1 ( i.e: "
-            f"[1, height, width]).",
-        )
-
-    if img.shape[-3] not in (1, 3):
-        raise TypeError(
-            f"Input image tensor should have 1 or 3 color channels,"
-            f"but found {img.shape[-3]}.",
-        )
-
-    if img.shape[-3] == 3:
-        return img
-
-    # it is a grayscale image - repeat the image 3 times
-    repetitions = img.dim() * [1]
-    repetitions[-3] = 3
-    return img.repeat(*repetitions)
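The removed helper builds RGB by repeating the single grayscale channel; a quick illustrative check of that repeat trick on a dummy tensor (shapes chosen arbitrarily):

import torch

# Replicate the channel axis (third from the end) three times, leaving every
# other dimension untouched - the same trick as in the removed code above.
gray = torch.rand(4, 1, 64, 64)           # batch x channels x height x width
repetitions = gray.dim() * [1]            # [1, 1, 1, 1]
repetitions[-3] = 3                       # [1, 3, 1, 1]
rgb = gray.repeat(*repetitions)
print(rgb.shape)                          # torch.Size([4, 3, 64, 64])
print(torch.equal(rgb[:, 0], rgb[:, 2]))  # True: the three channels are identical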
-def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
-    """Convert an image in RGB to grayscale.
-
-    If the image is already in grayscale format, then this is a NOOP - the same
-    tensor is returned (no cloning). If the image is in RGB format
-    (number of color channels = 3), then compresses the color channels into
-    a single grayscale channel following this equation:
-
-    .. math::
-
-       grayscale = (0.2989 * r + 0.587 * g + 0.114 * b)
-
-    A new tensor is returned in this case.
-
-    Parameters
-    ----------
-    img
-        The tensor to be transformed. Expected to be in the form: ``[...,
-        [1,3], H, W]`` (i.e. arbitrary number of leading dimensions).
-
-    Returns
-    -------
-    torch.Tensor
-        Transformed tensor with a single (grayscale) color channel.
-    """
-
-    if img.ndim < 3:
-        raise TypeError(
-            f"Input image tensor should have at least 3 dimensions, "
-            f"but found {img.ndim}. If a grayscale image was provided, "
-            f"ensure to include a channel dimension of size 1 ( i.e: "
-            f"[1, height, width]).",
-        )
-
-    if img.shape[-3] not in (1, 3):
-        raise TypeError(
-            f"Input image tensor should have 1 or 3 planes,"
-            f"but found {img.shape[-3]}",
-        )
-
-    if img.shape[-3] == 1:
-        return img
-
-    # it is an RGB image - use torchvision implementation
-    return torchvision.transforms.functional.rgb_to_grayscale(img)
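The luminance weights in that docstring are the ones torchvision applies; a small arithmetic check on a constant-color dummy image (values chosen arbitrarily):

import torch
import torchvision.transforms.functional

# One 2x2 RGB "image" with a constant color, to make the arithmetic visible.
r, g, b = 0.5, 0.25, 1.0
img = torch.tensor([r, g, b]).reshape(3, 1, 1).expand(3, 2, 2)

expected = 0.2989 * r + 0.587 * g + 0.114 * b  # 0.14945 + 0.14675 + 0.114 = 0.4102
gray = torchvision.transforms.functional.rgb_to_grayscale(img)

print(gray.shape)            # torch.Size([1, 2, 2])
print(float(gray[0, 0, 0]))  # ~0.4102, matching `expected` up to rounding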
 class SquareCenterPad(torch.nn.Module):
     """Transform to a squared version of the image, centered on a canvas padded
     with zeros.
@@ -195,25 +90,21 @@ class SquareCenterPad(torch.nn.Module):
         return square_center_pad(img)
 
 
-class RGB(torch.nn.Module):
-    """Wrapper class around :py:func:`.grayscale_to_rgb` to be used as a model
-    transform.
-    """
+class RGB(torchvision.transforms.v2.Grayscale):
+    """Convert images or videos to RGB (if they are already not RGB).
 
-    def __init__(self):
-        super().__init__()
+    If the input is a :class:`torch.Tensor`, it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions
 
-    def forward(self, img: torch.Tensor) -> torch.Tensor:
-        return grayscale_to_rgb(img)
+    .. note::
 
-
-class Grayscale(torch.nn.Module):
-    """Wrapper class around :py:func:`rgb_to_grayscale` to be used as a model
-    transform.
-    """
+       Copied from torchvision v0.18.1 source code.
+    """
 
     def __init__(self):
-        super().__init__()
+        super().__init__(num_output_channels=3)
 
-    def forward(self, img: torch.Tensor) -> torch.Tensor:
-        return rgb_to_grayscale(img)
+    def forward(self, inpt: typing.Any) -> typing.Any:  # type: ignore
+        if inpt.shape[-3] >= 3:
+            return inpt
+        return super().forward(inpt)
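To see why subclassing torchvision.transforms.v2.Grayscale with num_output_channels=3 yields a to-RGB transform, a hedged usage sketch (the mednet import path is inferred from the commit title, and the expected output shape assumes torchvision >= 0.18 behaviour):

import torch

# Assumed import path, derived from the [libs.common.models.transforms] tag.
from mednet.libs.common.models.transforms import RGB

to_rgb = RGB()

gray = torch.rand(1, 64, 64)               # single-channel input
print(to_rgb(gray).shape)                  # expected: torch.Size([3, 64, 64])

already_rgb = torch.rand(3, 64, 64)
print(to_rgb(already_rgb) is already_rgb)  # True: forward() returns the input unchanged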