diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py index 5e5b3e7eee3d810aaa4c8dfed31ad3d9fa5414b5..5177316da7d1b505b324429babfd226b5ae33d51 100644 --- a/src/mednet/models/alexnet.py +++ b/src/mednet/models/alexnet.py @@ -15,7 +15,7 @@ import torchvision.transforms from ..data.typing import TransformSequence from .separate import separate -from .transforms import RGB +from .transforms import RGB, SquareCenterPad from .typing import Checkpoint logger = logging.getLogger(__name__) @@ -80,6 +80,7 @@ class Alexnet(pl.LightningModule): self.num_classes = num_classes self.model_transforms = [ + SquareCenterPad(), torchvision.transforms.Resize(512, antialias=True), RGB(), ] diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py index 333edb11e2112db9ef1f1bf37b2d333f5b418de6..f6c4d916f1791d83653ed4cd986a4be025de5a43 100644 --- a/src/mednet/models/densenet.py +++ b/src/mednet/models/densenet.py @@ -15,7 +15,7 @@ import torchvision.transforms from ..data.typing import TransformSequence from .separate import separate -from .transforms import RGB +from .transforms import RGB, SquareCenterPad from .typing import Checkpoint logger = logging.getLogger(__name__) @@ -77,8 +77,8 @@ class Densenet(pl.LightningModule): self.name = "densenet-121" self.num_classes = num_classes - # image is probably large, resize first to get memory usage down self.model_transforms = [ + SquareCenterPad(), torchvision.transforms.Resize(512, antialias=True), RGB(), ] diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py index a7b8ae62c0c56bc2c5172058fdd3382c987514a1..21cfc2e1e6345b9318ae5e4a15e9b9a1981f6130 100644 --- a/src/mednet/models/pasa.py +++ b/src/mednet/models/pasa.py @@ -15,7 +15,7 @@ import torchvision.transforms from ..data.typing import TransformSequence from .separate import separate -from .transforms import Grayscale +from .transforms import Grayscale, SquareCenterPad from .typing import Checkpoint logger = logging.getLogger(__name__) @@ -78,10 +78,10 @@ class Pasa(pl.LightningModule): self.name = "pasa" self.num_classes = num_classes - # image is probably large, resize first to get memory usage down self.model_transforms = [ - torchvision.transforms.Resize(512, antialias=True), Grayscale(), + SquareCenterPad(), + torchvision.transforms.Resize(512, antialias=True), ] self._train_loss = train_loss diff --git a/src/mednet/models/transforms.py b/src/mednet/models/transforms.py index 8ff0af726aada9256603882dd8a6c0aeb6d34cad..92869c7b8210740c062bd5f67114c26481f96a58 100644 --- a/src/mednet/models/transforms.py +++ b/src/mednet/models/transforms.py @@ -3,11 +3,45 @@ # SPDX-License-Identifier: GPL-3.0-or-later """A transform that turns grayscale images to RGB.""" +import numpy import torch import torch.nn import torchvision.transforms.functional +def square_center_pad(img: torch.Tensor) -> torch.Tensor: + """Returns a squared version of the image, centered on a canvas padded with + zeros. + + Parameters + ---------- + + img + The tensor to be transformed. Expected to be in the form: ``[..., + [1,3], H, W]`` (i.e. arbitrary number of leading dimensions). + + Returns + ------- + + img + transformed tensor, guaranteed to be square (ie. equal height and + width). + """ + + height, width = img.shape[-2:] + maxdim = numpy.max([height, width]) + + # padding + left = (maxdim - width) // 2 + top = (maxdim - height) // 2 + right = maxdim - width - left + bottom = maxdim - height - top + + return torchvision.transforms.functional.pad( + img, [left, top, right, bottom], 0, "constant" + ) + + def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor: """Converts an image in grayscale to RGB. @@ -97,6 +131,17 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor: return torchvision.transforms.functional.rgb_to_grayscale(img) +class SquareCenterPad(torch.nn.Module): + """Transforms to a squared version of the image, centered on a canvas + padded with zeros.""" + + def __init__(self): + super().__init__() + + def forward(self, img: torch.Tensor) -> torch.Tensor: + return square_center_pad(img) + + class RGB(torch.nn.Module): """Converts an image in grayscale to RGB.