Skip to content
Snippets Groups Projects
Commit 6c981135 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[models.transforms] Implement RGB and Grayscale transforms, and hook those...

[models.transforms] Implement RGB and Grayscale transforms, and hook those into our models (closes #15)
parent 8b2a2902
No related branches found
No related tags found
No related merge requests found
Pipeline #76321 failed
...@@ -14,6 +14,7 @@ import torchvision.models as models ...@@ -14,6 +14,7 @@ import torchvision.models as models
import torchvision.transforms import torchvision.transforms
from ..data.typing import TransformSequence from ..data.typing import TransformSequence
from .transforms import RGB
from .typing import Checkpoint from .typing import Checkpoint
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -78,10 +79,8 @@ class Alexnet(pl.LightningModule): ...@@ -78,10 +79,8 @@ class Alexnet(pl.LightningModule):
self.name = "alexnet" self.name = "alexnet"
self.model_transforms = [ self.model_transforms = [
torchvision.transforms.Resize(512), torchvision.transforms.Resize(512, antialias=True),
torchvision.transforms.ToPILImage(), RGB(),
torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
torchvision.transforms.ToTensor(),
] ]
self._train_loss = train_loss self._train_loss = train_loss
...@@ -198,6 +197,12 @@ class Alexnet(pl.LightningModule): ...@@ -198,6 +197,12 @@ class Alexnet(pl.LightningModule):
if labels.ndim == 1: if labels.ndim == 1:
labels = torch.reshape(labels, (labels.shape[0], 1)) labels = torch.reshape(labels, (labels.shape[0], 1))
# debug code to inspect images by eye:
# from torchvision.transforms.functional import to_pil_image
# for k in images:
# to_pil_image(k).show()
# __import__("pdb").set_trace()
# data forwarding on the existing network # data forwarding on the existing network
outputs = self(images) outputs = self(images)
......
...@@ -14,6 +14,7 @@ import torchvision.models as models ...@@ -14,6 +14,7 @@ import torchvision.models as models
import torchvision.transforms import torchvision.transforms
from ..data.typing import TransformSequence from ..data.typing import TransformSequence
from .transforms import RGB
from .typing import Checkpoint from .typing import Checkpoint
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -75,11 +76,10 @@ class Densenet(pl.LightningModule): ...@@ -75,11 +76,10 @@ class Densenet(pl.LightningModule):
self.name = "densenet-121" self.name = "densenet-121"
# image is probably large, resize first to get memory usage down
self.model_transforms = [ self.model_transforms = [
torchvision.transforms.Resize(512), torchvision.transforms.Resize(512, antialias=True),
torchvision.transforms.ToPILImage(), RGB(),
torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
torchvision.transforms.ToTensor(),
] ]
self._train_loss = train_loss self._train_loss = train_loss
......
...@@ -14,6 +14,7 @@ import torch.utils.data ...@@ -14,6 +14,7 @@ import torch.utils.data
import torchvision.transforms import torchvision.transforms
from ..data.typing import TransformSequence from ..data.typing import TransformSequence
from .transforms import Grayscale
from .typing import Checkpoint from .typing import Checkpoint
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -72,9 +73,10 @@ class Pasa(pl.LightningModule): ...@@ -72,9 +73,10 @@ class Pasa(pl.LightningModule):
self.name = "pasa" self.name = "pasa"
# image is probably large, resize first to get memory usage down
self.model_transforms = [ self.model_transforms = [
torchvision.transforms.Grayscale(),
torchvision.transforms.Resize(512, antialias=True), torchvision.transforms.Resize(512, antialias=True),
Grayscale(),
] ]
self._train_loss = train_loss self._train_loss = train_loss
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Transforms that convert image tensors between grayscale and RGB."""
import torch
import torch.nn
import torchvision.transforms.functional
def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor:
    """Convert a grayscale image tensor to RGB.

    If the image is already in RGB format (3 planes), this is a NOOP and the
    very same tensor object is returned (no cloning).  If the image is in
    grayscale format (1 plane), that plane is replicated three times and a
    new tensor is returned.


    Parameters
    ----------
    img
        The tensor to be transformed.  Expected to be in the form: ``[...,
        [1,3], H, W]`` (i.e. arbitrary number of leading dimensions).


    Returns
    -------
    img
        Tensor where the 3rd dimension from the last has size 3.


    Raises
    ------
    TypeError
        If ``img`` has fewer than 3 dimensions, or if its 3rd dimension from
        the last is neither 1 nor 3.
    """
    if img.ndim < 3:
        raise TypeError(
            "Input image tensor should have at least 3 dimensions, "
            f"but found {img.ndim}"
        )

    planes = img.shape[-3]
    if planes not in (1, 3):
        raise TypeError(
            "Input image tensor should have 1 or 3 planes, "
            f"but found {planes}"
        )

    if planes == 3:
        # already RGB: hand back the input untouched
        return img

    # grayscale: tile only the plane axis, leave every other axis as-is.
    # ``repeat`` allocates the result on the input's device, so no explicit
    # device transfer is needed afterwards.
    repetitions = [1] * img.ndim
    repetitions[-3] = 3
    return img.repeat(*repetitions)
def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
    """Convert an RGB image tensor to grayscale.

    If the image is already in grayscale format (1 plane), this is a NOOP and
    the very same tensor object is returned (no cloning).  If the image is in
    RGB format (3 planes), the conversion is delegated to
    :py:func:`torchvision.transforms.functional.rgb_to_grayscale`, which
    compresses the color planes following this equation:

    .. math::

       grayscale = (0.2989 * r + 0.587 * g + 0.114 * b)

    A new tensor is returned in this case.


    Parameters
    ----------
    img
        The tensor to be transformed.  Expected to be in the form: ``[...,
        [1,3], H, W]`` (i.e. arbitrary number of leading dimensions).


    Returns
    -------
    img
        Tensor where the 3rd dimension from the last has size 1.


    Raises
    ------
    TypeError
        If ``img`` has fewer than 3 dimensions, or if its 3rd dimension from
        the last is neither 1 nor 3.
    """
    if img.ndim < 3:
        raise TypeError(
            "Input image tensor should have at least 3 dimensions, "
            f"but found {img.ndim}"
        )

    planes = img.shape[-3]
    if planes not in (1, 3):
        raise TypeError(
            "Input image tensor should have 1 or 3 planes, "
            f"but found {planes}"
        )

    if planes == 1:
        # already grayscale: hand back the input untouched
        return img

    # it is an RGB image - use torchvision implementation
    return torchvision.transforms.functional.rgb_to_grayscale(img)
class RGB(torch.nn.Module):
    """Module wrapper that applies :py:func:`grayscale_to_rgb` to its input.

    The conversion itself (including validation of the input tensor) is
    performed entirely by :py:func:`grayscale_to_rgb`; this class only makes
    it usable wherever a ``torch.nn.Module`` transform is expected.
    """

    def __init__(self):
        # stateless transform: nothing to configure beyond the base module
        super().__init__()

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Return ``grayscale_to_rgb(img)``."""
        converted = grayscale_to_rgb(img)
        return converted
class Grayscale(torch.nn.Module):
    """Module wrapper that applies :py:func:`rgb_to_grayscale` to its input.

    The conversion itself (including validation of the input tensor) is
    performed entirely by :py:func:`rgb_to_grayscale`; this class only makes
    it usable wherever a ``torch.nn.Module`` transform is expected.
    """

    def __init__(self):
        # stateless transform: nothing to configure beyond the base module
        super().__init__()

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Return ``rgb_to_grayscale(img)``."""
        converted = rgb_to_grayscale(img)
        return converted
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment