Skip to content
Snippets Groups Projects
Commit 44aa532c authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[models] Apply square-center-padding on all relevant model transforms (closes #39)

parent 85390375
No related branches found
No related tags found
No related merge requests found
Pipeline #83532 failed
...@@ -15,7 +15,7 @@ import torchvision.transforms ...@@ -15,7 +15,7 @@ import torchvision.transforms
from ..data.typing import TransformSequence from ..data.typing import TransformSequence
from .separate import separate from .separate import separate
from .transforms import RGB from .transforms import RGB, SquareCenterPad
from .typing import Checkpoint from .typing import Checkpoint
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -80,6 +80,7 @@ class Alexnet(pl.LightningModule): ...@@ -80,6 +80,7 @@ class Alexnet(pl.LightningModule):
self.num_classes = num_classes self.num_classes = num_classes
self.model_transforms = [ self.model_transforms = [
SquareCenterPad(),
torchvision.transforms.Resize(512, antialias=True), torchvision.transforms.Resize(512, antialias=True),
RGB(), RGB(),
] ]
......
...@@ -15,7 +15,7 @@ import torchvision.transforms ...@@ -15,7 +15,7 @@ import torchvision.transforms
from ..data.typing import TransformSequence from ..data.typing import TransformSequence
from .separate import separate from .separate import separate
from .transforms import RGB from .transforms import RGB, SquareCenterPad
from .typing import Checkpoint from .typing import Checkpoint
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -77,8 +77,8 @@ class Densenet(pl.LightningModule): ...@@ -77,8 +77,8 @@ class Densenet(pl.LightningModule):
self.name = "densenet-121" self.name = "densenet-121"
self.num_classes = num_classes self.num_classes = num_classes
# image is probably large, resize first to get memory usage down
self.model_transforms = [ self.model_transforms = [
SquareCenterPad(),
torchvision.transforms.Resize(512, antialias=True), torchvision.transforms.Resize(512, antialias=True),
RGB(), RGB(),
] ]
......
...@@ -15,7 +15,7 @@ import torchvision.transforms ...@@ -15,7 +15,7 @@ import torchvision.transforms
from ..data.typing import TransformSequence from ..data.typing import TransformSequence
from .separate import separate from .separate import separate
from .transforms import Grayscale from .transforms import Grayscale, SquareCenterPad
from .typing import Checkpoint from .typing import Checkpoint
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -78,10 +78,10 @@ class Pasa(pl.LightningModule): ...@@ -78,10 +78,10 @@ class Pasa(pl.LightningModule):
self.name = "pasa" self.name = "pasa"
self.num_classes = num_classes self.num_classes = num_classes
# image is probably large, resize first to get memory usage down
self.model_transforms = [ self.model_transforms = [
torchvision.transforms.Resize(512, antialias=True),
Grayscale(), Grayscale(),
SquareCenterPad(),
torchvision.transforms.Resize(512, antialias=True),
] ]
self._train_loss = train_loss self._train_loss = train_loss
......
...@@ -3,11 +3,45 @@ ...@@ -3,11 +3,45 @@
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""A transform that turns grayscale images to RGB.""" """A transform that turns grayscale images to RGB."""
import numpy
import torch import torch
import torch.nn import torch.nn
import torchvision.transforms.functional import torchvision.transforms.functional
def square_center_pad(img: torch.Tensor) -> torch.Tensor:
"""Returns a squared version of the image, centered on a canvas padded with
zeros.
Parameters
----------
img
The tensor to be transformed. Expected to be in the form: ``[...,
[1,3], H, W]`` (i.e. arbitrary number of leading dimensions).
Returns
-------
img
transformed tensor, guaranteed to be square (ie. equal height and
width).
"""
height, width = img.shape[-2:]
maxdim = numpy.max([height, width])
# padding
left = (maxdim - width) // 2
top = (maxdim - height) // 2
right = maxdim - width - left
bottom = maxdim - height - top
return torchvision.transforms.functional.pad(
img, [left, top, right, bottom], 0, "constant"
)
def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor: def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor:
"""Converts an image in grayscale to RGB. """Converts an image in grayscale to RGB.
...@@ -97,6 +131,17 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor: ...@@ -97,6 +131,17 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
return torchvision.transforms.functional.rgb_to_grayscale(img) return torchvision.transforms.functional.rgb_to_grayscale(img)
class SquareCenterPad(torch.nn.Module):
"""Transforms to a squared version of the image, centered on a canvas
padded with zeros."""
def __init__(self):
super().__init__()
def forward(self, img: torch.Tensor) -> torch.Tensor:
return square_center_pad(img)
class RGB(torch.nn.Module): class RGB(torch.nn.Module):
"""Converts an image in grayscale to RGB. """Converts an image in grayscale to RGB.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment