Skip to content
Snippets Groups Projects
Commit d2d83ca3 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[models] Apply square-center-padding on all relevant model transforms (closes #39)

parent 1a54dbf4
No related branches found
No related tags found
1 merge request: !16 "Make square centre-padding a model transform"
Pipeline #84055 failed
...@@ -15,7 +15,7 @@ import torchvision.transforms ...@@ -15,7 +15,7 @@ import torchvision.transforms
from ..data.typing import TransformSequence from ..data.typing import TransformSequence
from .separate import separate from .separate import separate
from .transforms import RGB from .transforms import RGB, SquareCenterPad
from .typing import Checkpoint from .typing import Checkpoint
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -74,6 +74,7 @@ class Alexnet(pl.LightningModule): ...@@ -74,6 +74,7 @@ class Alexnet(pl.LightningModule):
self.num_classes = num_classes self.num_classes = num_classes
self.model_transforms = [ self.model_transforms = [
SquareCenterPad(),
torchvision.transforms.Resize(512, antialias=True), torchvision.transforms.Resize(512, antialias=True),
RGB(), RGB(),
] ]
......
...@@ -15,7 +15,7 @@ import torchvision.transforms ...@@ -15,7 +15,7 @@ import torchvision.transforms
from ..data.typing import TransformSequence from ..data.typing import TransformSequence
from .separate import separate from .separate import separate
from .transforms import RGB from .transforms import RGB, SquareCenterPad
from .typing import Checkpoint from .typing import Checkpoint
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -71,8 +71,8 @@ class Densenet(pl.LightningModule): ...@@ -71,8 +71,8 @@ class Densenet(pl.LightningModule):
self.name = "densenet-121" self.name = "densenet-121"
self.num_classes = num_classes self.num_classes = num_classes
# image is probably large, resize first to get memory usage down
self.model_transforms = [ self.model_transforms = [
SquareCenterPad(),
torchvision.transforms.Resize(512, antialias=True), torchvision.transforms.Resize(512, antialias=True),
RGB(), RGB(),
] ]
......
...@@ -15,7 +15,7 @@ import torchvision.transforms ...@@ -15,7 +15,7 @@ import torchvision.transforms
from ..data.typing import TransformSequence from ..data.typing import TransformSequence
from .separate import separate from .separate import separate
from .transforms import Grayscale from .transforms import Grayscale, SquareCenterPad
from .typing import Checkpoint from .typing import Checkpoint
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -72,10 +72,10 @@ class Pasa(pl.LightningModule): ...@@ -72,10 +72,10 @@ class Pasa(pl.LightningModule):
self.name = "pasa" self.name = "pasa"
self.num_classes = num_classes self.num_classes = num_classes
# image is probably large, resize first to get memory usage down
self.model_transforms = [ self.model_transforms = [
torchvision.transforms.Resize(512, antialias=True),
Grayscale(), Grayscale(),
SquareCenterPad(),
torchvision.transforms.Resize(512, antialias=True),
] ]
self._train_loss = train_loss self._train_loss = train_loss
......
...@@ -3,11 +3,45 @@ ...@@ -3,11 +3,45 @@
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""A transform that turns grayscale images to RGB.""" """A transform that turns grayscale images to RGB."""
import numpy
import torch import torch
import torch.nn import torch.nn
import torchvision.transforms.functional import torchvision.transforms.functional
def square_center_pad(img: torch.Tensor) -> torch.Tensor:
    """Return a squared version of the image, centered on a zero-padded canvas.

    The shorter of the two trailing dimensions is padded with zeros until it
    matches the longer one.  When the required padding is odd, the extra
    row/column is placed on the bottom/right side.

    Parameters
    ----------
    img
        The tensor to be transformed.  Expected to be in the form: ``[...,
        [1,3], H, W]`` (i.e. arbitrary number of leading dimensions).

    Returns
    -------
    img
        Transformed tensor, guaranteed to be square (i.e. equal height and
        width).
    """

    height, width = img.shape[-2:]

    # The builtin max() is enough to compare two scalars; it avoids building
    # a temporary numpy array and does not turn the result (and every padding
    # value derived from it) into a numpy scalar.
    maxdim = max(height, width)

    # Split the missing size evenly; any odd remainder goes right/bottom.
    left = (maxdim - width) // 2
    top = (maxdim - height) // 2
    right = maxdim - width - left
    bottom = maxdim - height - top

    # torchvision expects the padding in [left, top, right, bottom] order.
    return torchvision.transforms.functional.pad(
        img, [left, top, right, bottom], 0, "constant"
    )
def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor: def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor:
"""Convert an image in grayscale to RGB. """Convert an image in grayscale to RGB.
...@@ -94,6 +128,17 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor: ...@@ -94,6 +128,17 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
return torchvision.transforms.functional.rgb_to_grayscale(img) return torchvision.transforms.functional.rgb_to_grayscale(img)
class SquareCenterPad(torch.nn.Module):
    """Wrapper class around :py:func:`square_center_pad` to be used as a model
    transform.

    Turns the input into a squared version of the image, centered on a canvas
    padded with zeros.
    """

    def __init__(self):
        super().__init__()

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Apply :py:func:`square_center_pad` to ``img`` and return the result."""
        squared = square_center_pad(img)
        return squared
class RGB(torch.nn.Module): class RGB(torch.nn.Module):
"""Wrapper class around :py:func:`.grayscale_to_rgb` to be used as a model transform.""" """Wrapper class around :py:func:`.grayscale_to_rgb` to be used as a model transform."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment