Skip to content
Snippets Groups Projects
Commit 5e5f5b8a authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[models] Apply square-center-padding on all relevant model transforms (closes #39)

parent d43b2c49
No related branches found
No related tags found
No related merge requests found
Pipeline #84048 failed
......@@ -15,7 +15,7 @@ import torchvision.transforms
from ..data.typing import TransformSequence
from .separate import separate
from .transforms import RGB
from .transforms import RGB, SquareCenterPad
from .typing import Checkpoint
logger = logging.getLogger(__name__)
......@@ -74,6 +74,7 @@ class Alexnet(pl.LightningModule):
self.num_classes = num_classes
self.model_transforms = [
SquareCenterPad(),
torchvision.transforms.Resize(512, antialias=True),
RGB(),
]
......
......@@ -15,7 +15,7 @@ import torchvision.transforms
from ..data.typing import TransformSequence
from .separate import separate
from .transforms import RGB
from .transforms import RGB, SquareCenterPad
from .typing import Checkpoint
logger = logging.getLogger(__name__)
......@@ -71,8 +71,8 @@ class Densenet(pl.LightningModule):
self.name = "densenet-121"
self.num_classes = num_classes
# image is probably large, resize first to get memory usage down
self.model_transforms = [
SquareCenterPad(),
torchvision.transforms.Resize(512, antialias=True),
RGB(),
]
......
......@@ -15,7 +15,7 @@ import torchvision.transforms
from ..data.typing import TransformSequence
from .separate import separate
from .transforms import Grayscale
from .transforms import Grayscale, SquareCenterPad
from .typing import Checkpoint
logger = logging.getLogger(__name__)
......@@ -72,10 +72,10 @@ class Pasa(pl.LightningModule):
self.name = "pasa"
self.num_classes = num_classes
# image is probably large, resize first to get memory usage down
self.model_transforms = [
torchvision.transforms.Resize(512, antialias=True),
Grayscale(),
SquareCenterPad(),
torchvision.transforms.Resize(512, antialias=True),
]
self._train_loss = train_loss
......
......@@ -3,11 +3,45 @@
# SPDX-License-Identifier: GPL-3.0-or-later
"""A transform that turns grayscale images to RGB."""
import numpy
import torch
import torch.nn
import torchvision.transforms.functional
def square_center_pad(img: torch.Tensor) -> torch.Tensor:
"""Returns a squared version of the image, centered on a canvas padded with
zeros.
Parameters
----------
img
The tensor to be transformed. Expected to be in the form: ``[...,
[1,3], H, W]`` (i.e. arbitrary number of leading dimensions).
Returns
-------
img
transformed tensor, guaranteed to be square (ie. equal height and
width).
"""
height, width = img.shape[-2:]
maxdim = numpy.max([height, width])
# padding
left = (maxdim - width) // 2
top = (maxdim - height) // 2
right = maxdim - width - left
bottom = maxdim - height - top
return torchvision.transforms.functional.pad(
img, [left, top, right, bottom], 0, "constant"
)
def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor:
"""Convert an image in grayscale to RGB.
......@@ -94,6 +128,17 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
return torchvision.transforms.functional.rgb_to_grayscale(img)
class SquareCenterPad(torch.nn.Module):
"""Transforms to a squared version of the image, centered on a canvas
padded with zeros."""
def __init__(self):
super().__init__()
def forward(self, img: torch.Tensor) -> torch.Tensor:
return square_center_pad(img)
class RGB(torch.nn.Module):
"""Wrapper class around :py:func:`.grayscale_to_rgb` to be used as a model transform."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment