From 44aa532c670404782249a154b3aec5c466d59c65 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Wed, 24 Jan 2024 18:29:38 +0100 Subject: [PATCH] [models] Apply square-center-padding on all relevant model transforms (closes #39) --- src/mednet/models/alexnet.py | 3 ++- src/mednet/models/densenet.py | 4 +-- src/mednet/models/pasa.py | 6 ++--- src/mednet/models/transforms.py | 45 +++++++++++++++++++++++++++++++++ 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py index 5e5b3e7e..5177316d 100644 --- a/src/mednet/models/alexnet.py +++ b/src/mednet/models/alexnet.py @@ -15,7 +15,7 @@ import torchvision.transforms from ..data.typing import TransformSequence from .separate import separate -from .transforms import RGB +from .transforms import RGB, SquareCenterPad from .typing import Checkpoint logger = logging.getLogger(__name__) @@ -80,6 +80,7 @@ class Alexnet(pl.LightningModule): self.num_classes = num_classes self.model_transforms = [ + SquareCenterPad(), torchvision.transforms.Resize(512, antialias=True), RGB(), ] diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py index 333edb11..f6c4d916 100644 --- a/src/mednet/models/densenet.py +++ b/src/mednet/models/densenet.py @@ -15,7 +15,7 @@ import torchvision.transforms from ..data.typing import TransformSequence from .separate import separate -from .transforms import RGB +from .transforms import RGB, SquareCenterPad from .typing import Checkpoint logger = logging.getLogger(__name__) @@ -77,8 +77,8 @@ class Densenet(pl.LightningModule): self.name = "densenet-121" self.num_classes = num_classes - # image is probably large, resize first to get memory usage down self.model_transforms = [ + SquareCenterPad(), torchvision.transforms.Resize(512, antialias=True), RGB(), ] diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py index a7b8ae62..21cfc2e1 100644 --- a/src/mednet/models/pasa.py +++ b/src/mednet/models/pasa.py @@ -15,7 +15,7 @@ import torchvision.transforms from ..data.typing import TransformSequence from .separate import separate -from .transforms import Grayscale +from .transforms import Grayscale, SquareCenterPad from .typing import Checkpoint logger = logging.getLogger(__name__) @@ -78,10 +78,10 @@ class Pasa(pl.LightningModule): self.name = "pasa" self.num_classes = num_classes - # image is probably large, resize first to get memory usage down self.model_transforms = [ - torchvision.transforms.Resize(512, antialias=True), Grayscale(), + SquareCenterPad(), + torchvision.transforms.Resize(512, antialias=True), ] self._train_loss = train_loss diff --git a/src/mednet/models/transforms.py b/src/mednet/models/transforms.py index 8ff0af72..92869c7b 100644 --- a/src/mednet/models/transforms.py +++ b/src/mednet/models/transforms.py @@ -3,11 +3,45 @@ # SPDX-License-Identifier: GPL-3.0-or-later """A transform that turns grayscale images to RGB.""" +import numpy import torch import torch.nn import torchvision.transforms.functional +def square_center_pad(img: torch.Tensor) -> torch.Tensor: + """Returns a squared version of the image, centered on a canvas padded with + zeros. + + Parameters + ---------- + + img + The tensor to be transformed. Expected to be in the form: ``[..., + [1,3], H, W]`` (i.e. arbitrary number of leading dimensions). + + Returns + ------- + + img + transformed tensor, guaranteed to be square (ie. equal height and + width). + """ + + height, width = img.shape[-2:] + maxdim = numpy.max([height, width]) + + # padding + left = (maxdim - width) // 2 + top = (maxdim - height) // 2 + right = maxdim - width - left + bottom = maxdim - height - top + + return torchvision.transforms.functional.pad( + img, [left, top, right, bottom], 0, "constant" + ) + + def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor: """Converts an image in grayscale to RGB. @@ -97,6 +131,17 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor: return torchvision.transforms.functional.rgb_to_grayscale(img) +class SquareCenterPad(torch.nn.Module): + """Transforms to a squared version of the image, centered on a canvas + padded with zeros.""" + + def __init__(self): + super().__init__() + + def forward(self, img: torch.Tensor) -> torch.Tensor: + return square_center_pad(img) + + class RGB(torch.nn.Module): """Converts an image in grayscale to RGB. -- GitLab