From 6c981135b2fa61d1e31db4380698e23bc83fae4d Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Fri, 21 Jul 2023 20:36:51 +0200 Subject: [PATCH] [models.transforms] Implement RGB and Grayscale transforms, and hook those into our models (closes #15) --- src/ptbench/models/alexnet.py | 13 ++- src/ptbench/models/densenet.py | 8 +- src/ptbench/models/pasa.py | 4 +- src/ptbench/models/transforms.py | 135 +++++++++++++++++++++++++++++++ 4 files changed, 151 insertions(+), 9 deletions(-) create mode 100644 src/ptbench/models/transforms.py diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index df224bc5..5daed2a3 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -14,6 +14,7 @@ import torchvision.models as models import torchvision.transforms from ..data.typing import TransformSequence +from .transforms import RGB from .typing import Checkpoint logger = logging.getLogger(__name__) @@ -78,10 +79,8 @@ class Alexnet(pl.LightningModule): self.name = "alexnet" self.model_transforms = [ - torchvision.transforms.Resize(512), - torchvision.transforms.ToPILImage(), - torchvision.transforms.Lambda(lambda x: x.convert("RGB")), - torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(512, antialias=True), + RGB(), ] self._train_loss = train_loss @@ -198,6 +197,12 @@ class Alexnet(pl.LightningModule): if labels.ndim == 1: labels = torch.reshape(labels, (labels.shape[0], 1)) + # debug code to inspect images by eye: + # from torchvision.transforms.functional import to_pil_image + # for k in images: + # to_pil_image(k).show() + # __import__("pdb").set_trace() + # data forwarding on the existing network outputs = self(images) diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index 436d73ba..3e2ee50e 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -14,6 +14,7 @@ import torchvision.models as models import torchvision.transforms from ..data.typing import TransformSequence +from .transforms import RGB from .typing import Checkpoint logger = logging.getLogger(__name__) @@ -75,11 +76,10 @@ class Densenet(pl.LightningModule): self.name = "densenet-121" + # image is probably large, resize first to get memory usage down self.model_transforms = [ - torchvision.transforms.Resize(512), - torchvision.transforms.ToPILImage(), - torchvision.transforms.Lambda(lambda x: x.convert("RGB")), - torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(512, antialias=True), + RGB(), ] self._train_loss = train_loss diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 650aa7d4..3f10757f 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -14,6 +14,7 @@ import torch.utils.data import torchvision.transforms from ..data.typing import TransformSequence +from .transforms import Grayscale from .typing import Checkpoint logger = logging.getLogger(__name__) @@ -72,9 +73,10 @@ class Pasa(pl.LightningModule): self.name = "pasa" + # image is probably large, resize first to get memory usage down self.model_transforms = [ - torchvision.transforms.Grayscale(), torchvision.transforms.Resize(512, antialias=True), + Grayscale(), ] self._train_loss = train_loss diff --git a/src/ptbench/models/transforms.py b/src/ptbench/models/transforms.py new file mode 100644 index 00000000..585b5846 --- /dev/null +++ b/src/ptbench/models/transforms.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""A transform that turns grayscale images to RGB.""" + +import torch +import torch.nn +import torchvision.transforms.functional + + +def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor: + """Converts an image in grayscale to RGB. + + If the image is already in RGB format, then this is a NOOP - the same + tensor is returned (no cloning). If the image is in grayscale format + (number of bands = 1), then triplicate that band 3 times (a new copy is + returned in this case). + + + Parameters + ---------- + + img + The tensor to be transformed. Expected to be in the form: ``[..., + [1,3], H, W]`` (i.e. arbitrary number of leading dimensions). + + Returns + ------- + + img + transformed tensor where the 3rd dimension from the last is 3. + """ + if img.ndim < 3: + raise TypeError( + f"Input image tensor should have at least 3 dimensions," + f"but found {img.ndim}" + ) + + if img.shape[-3] not in (1, 3): + raise TypeError( + f"Input image tensor should have 1 or 3 planes," + f"but found {img.shape[-3]}" + ) + + if img.shape[-3] == 3: + return img + + # it is a grayscale image - repeat the image 3 times + repetitions = img.dim() * [1] + repetitions[-3] = 3 + return img.repeat(*repetitions).to(img.device) + + +def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor: + """Converts an image in RGB to grayscale. + + If the image is already in grayscale format, then this is a NOOP - the same + tensor is returned (no cloning). If the image is in RGB format, then + compresses the color planes into grayscale following this equation: + + .. math:: + + grayscale = (0.2989 * r + 0.587 * g + 0.114 * b) + + A new tensor is returned in this case. + + + Parameters + ---------- + + img + The tensor to be transformed. Expected to be in the form: ``[..., + [1,3], H, W]`` (i.e. arbitrary number of leading dimensions). + + Returns + ------- + + img + transformed tensor where the 3rd dimension from the last is 3. + """ + if img.ndim < 3: + raise TypeError( + f"Input image tensor should have at least 3 dimensions," + f"but found {img.ndim}" + ) + + if img.shape[-3] not in (1, 3): + raise TypeError( + f"Input image tensor should have 1 or 3 planes," + f"but found {img.shape[-3]}" + ) + + if img.shape[-3] == 1: + return img + + # it is an RGB image - use torchvision implementation + return torchvision.transforms.functional.rgb_to_grayscale(img) + + +class RGB(torch.nn.Module): + """Converts an image in grayscale to RGB. + + If the image is already in RGB format, then this is a NOOP - the same + tensor is returned (no cloning). If the image is in grayscale format + (number of bands = 1), then triplicate that band 3 times (a new copy is + returned in this case). + """ + + def __init__(self): + super().__init__() + + def forward(self, img: torch.Tensor) -> torch.Tensor: + return grayscale_to_rgb(img) + + +class Grayscale(torch.nn.Module): + """Converts an image in RGB to grayscale. + + If the image is already in grayscale format, then this is a NOOP - the same + tensor is returned (no cloning). If the image is in RGB format, then + compresses the color planes into grayscale following this equation: + + .. math:: + + grayscale = (0.2989 * r + 0.587 * g + 0.114 * b) + + A new tensor is returned in this case. + """ + + def __init__(self): + super().__init__() + + def forward(self, img: torch.Tensor) -> torch.Tensor: + return rgb_to_grayscale(img) -- GitLab