Skip to content
Snippets Groups Projects

Make square centre-padding a model transform

Merged André Anjos requested to merge issue-23-and-39 into main
1 file
+ 3
5
Compare changes
  • Side-by-side
  • Inline
@@ -3,11 +3,44 @@
# SPDX-License-Identifier: GPL-3.0-or-later
"""A transform that turns grayscale images to RGB."""
import numpy
import torch
import torch.nn
import torchvision.transforms.functional
def square_center_pad(img: torch.Tensor) -> torch.Tensor:
"""Return a squared version of the image, centered on a canvas padded with
zeros.
Parameters
----------
img
The tensor to be transformed. Expected to be in the form: ``[...,
[1,3], H, W]`` (i.e. arbitrary number of leading dimensions).
Returns
-------
Transformed tensor, guaranteed to be square (ie. equal height and
width).
"""
height, width = img.shape[-2:]
maxdim = numpy.max([height, width])
# padding
left = (maxdim - width) // 2
top = (maxdim - height) // 2
right = maxdim - width - left
bottom = maxdim - height - top
return torchvision.transforms.functional.pad(
img, [left, top, right, bottom], 0, "constant"
)
def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor:
"""Convert an image in grayscale to RGB.
@@ -94,6 +127,16 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
return torchvision.transforms.functional.rgb_to_grayscale(img)
class SquareCenterPad(torch.nn.Module):
"""Transform to a squared version of the image, centered on a canvas padded with zeros."""
def __init__(self):
super().__init__()
def forward(self, img: torch.Tensor) -> torch.Tensor:
return square_center_pad(img)
class RGB(torch.nn.Module):
"""Wrapper class around :py:func:`.grayscale_to_rgb` to be used as a model transform."""
Loading