diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py
index 0b2d9e180f0333da40f2d568ca2202b3ddb43b9b..d2635970f35497f6274b5ea87488f61e8c6a3a30 100644
--- a/src/mednet/models/alexnet.py
+++ b/src/mednet/models/alexnet.py
@@ -15,7 +15,7 @@ import torchvision.transforms
 
 from ..data.typing import TransformSequence
 from .separate import separate
-from .transforms import RGB
+from .transforms import RGB, SquareCenterPad
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -74,6 +74,7 @@ class Alexnet(pl.LightningModule):
         self.num_classes = num_classes
 
         self.model_transforms = [
+            SquareCenterPad(),
             torchvision.transforms.Resize(512, antialias=True),
             RGB(),
         ]
diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py
index 2bd6a2dd01ecf8389b1f32a4e783b0c4315af26e..91f7c336457d4bda68342be49cb77bb2b22d47c3 100644
--- a/src/mednet/models/densenet.py
+++ b/src/mednet/models/densenet.py
@@ -15,7 +15,7 @@ import torchvision.transforms
 
 from ..data.typing import TransformSequence
 from .separate import separate
-from .transforms import RGB
+from .transforms import RGB, SquareCenterPad
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -71,8 +71,8 @@ class Densenet(pl.LightningModule):
         self.name = "densenet-121"
         self.num_classes = num_classes
 
-        # image is probably large, resize first to get memory usage down
         self.model_transforms = [
+            SquareCenterPad(),
             torchvision.transforms.Resize(512, antialias=True),
             RGB(),
         ]
diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py
index 9007ab1b8d7e5c79b29bc65d4a996783bfc7889e..38ad0218ffd8a0a35573896175d32d9ff6a521af 100644
--- a/src/mednet/models/pasa.py
+++ b/src/mednet/models/pasa.py
@@ -15,7 +15,7 @@ import torchvision.transforms
 
 from ..data.typing import TransformSequence
 from .separate import separate
-from .transforms import Grayscale
+from .transforms import Grayscale, SquareCenterPad
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -72,10 +72,10 @@ class Pasa(pl.LightningModule):
         self.name = "pasa"
         self.num_classes = num_classes
 
-        # image is probably large, resize first to get memory usage down
         self.model_transforms = [
-            torchvision.transforms.Resize(512, antialias=True),
             Grayscale(),
+            SquareCenterPad(),
+            torchvision.transforms.Resize(512, antialias=True),
         ]
 
         self._train_loss = train_loss
diff --git a/src/mednet/models/transforms.py b/src/mednet/models/transforms.py
index 4ee73548a5b458ed6e5e043fb5922dcb752cb126..afb158ac7140de1554250a05751acf47e3a90b02 100644
--- a/src/mednet/models/transforms.py
+++ b/src/mednet/models/transforms.py
@@ -3,11 +3,45 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """A transform that turns grayscale images to RGB."""
 
+import numpy
 import torch
 import torch.nn
 import torchvision.transforms.functional
 
 
+def square_center_pad(img: torch.Tensor) -> torch.Tensor:
+    """Return a squared version of the image, centered on a canvas padded
+    with zeros.
+
+    Parameters
+    ----------
+
+    img
+        The tensor to be transformed. Expected to be in the form: ``[...,
+        [1,3], H, W]`` (i.e. arbitrary number of leading dimensions).
+
+    Returns
+    -------
+
+    img
+        Transformed tensor, guaranteed to be square (i.e. equal height and
+        width).
+    """
+
+    height, width = img.shape[-2:]
+    maxdim = numpy.max([height, width])
+
+    # padding required on each side to center the image on a square canvas
+    left = (maxdim - width) // 2
+    top = (maxdim - height) // 2
+    right = maxdim - width - left
+    bottom = maxdim - height - top
+
+    return torchvision.transforms.functional.pad(
+        img, [left, top, right, bottom], 0, "constant"
+    )
+
+
 def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor:
     """Convert an image in grayscale to RGB.
 
@@ -94,6 +128,17 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
     return torchvision.transforms.functional.rgb_to_grayscale(img)
 
 
+class SquareCenterPad(torch.nn.Module):
+    """Transform an image to a squared version, centered on a canvas padded
+    with zeros."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, img: torch.Tensor) -> torch.Tensor:
+        return square_center_pad(img)
+
+
 class RGB(torch.nn.Module):
     """Wrapper class around :py:func:`.grayscale_to_rgb` to be used as a
     model transform."""