diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index df224bc5f2eb973003b87a8170b0b53fd70b5aad..5daed2a31a2834997d7d90df5a444a68da4d831f 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -14,6 +14,7 @@ import torchvision.models as models
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .transforms import RGB
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -78,10 +79,8 @@ class Alexnet(pl.LightningModule):
         self.name = "alexnet"
 
         self.model_transforms = [
-            torchvision.transforms.Resize(512),
-            torchvision.transforms.ToPILImage(),
-            torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
-            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Resize(512, antialias=True),
+            RGB(),
         ]
 
         self._train_loss = train_loss
@@ -198,6 +197,12 @@ class Alexnet(pl.LightningModule):
         if labels.ndim == 1:
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
+        # debug code to inspect images by eye:
+        # from torchvision.transforms.functional import to_pil_image
+        # for k in images:
+        #     to_pil_image(k).show()
+        #     __import__("pdb").set_trace()
+
         # data forwarding on the existing network
         outputs = self(images)
 
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 436d73ba55d5829cd886b37ebee882b80c41d2d8..3e2ee50e664c9a00bb34d8ea5bee298c3fba893a 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -14,6 +14,7 @@ import torchvision.models as models
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .transforms import RGB
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -75,11 +76,10 @@ class Densenet(pl.LightningModule):
 
         self.name = "densenet-121"
 
+        # image is probably large, resize first to get memory usage down
         self.model_transforms = [
-            torchvision.transforms.Resize(512),
-            torchvision.transforms.ToPILImage(),
-            torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
-            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Resize(512, antialias=True),
+            RGB(),
         ]
 
         self._train_loss = train_loss
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 650aa7d4f60dfe5feda914deb8518415fda8876d..3f10757fe2126b5be67ef577bccd8c8bbefe3fa5 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -14,6 +14,7 @@ import torch.utils.data
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .transforms import Grayscale
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -72,9 +73,10 @@ class Pasa(pl.LightningModule):
 
         self.name = "pasa"
 
+        # image is probably large, resize first to get memory usage down
         self.model_transforms = [
-            torchvision.transforms.Grayscale(),
             torchvision.transforms.Resize(512, antialias=True),
+            Grayscale(),
         ]
 
         self._train_loss = train_loss
diff --git a/src/ptbench/models/transforms.py b/src/ptbench/models/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..585b5846151292c69c8bb3411f68e28dc5eb97e2
--- /dev/null
+++ b/src/ptbench/models/transforms.py
@@ -0,0 +1,150 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Transforms that convert image tensors between grayscale and RGB."""
+
+import torch
+import torch.nn
+import torchvision.transforms.functional
+
+
+def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor:
+    """Converts an image in grayscale to RGB.
+
+    If the image is already in RGB format, then this is a NOOP - the same
+    tensor is returned (no cloning). If the image is in grayscale format
+    (number of bands = 1), then triplicate that band 3 times (a new copy is
+    returned in this case).
+
+
+    Parameters
+    ----------
+
+    img
+        The tensor to be transformed. Expected to be in the form: ``[...,
+        [1,3], H, W]`` (i.e. arbitrary number of leading dimensions).
+
+    Returns
+    -------
+
+    img
+        transformed tensor where the 3rd dimension from the last is 3.
+
+    Raises
+    ------
+
+    TypeError
+        If ``img`` has fewer than 3 dimensions, or if its 3rd dimension from
+        the last is neither 1 nor 3.
+    """
+    if img.ndim < 3:
+        raise TypeError(
+            "Input image tensor should have at least 3 dimensions, "
+            f"but found {img.ndim}"
+        )
+
+    if img.shape[-3] not in (1, 3):
+        raise TypeError(
+            "Input image tensor should have 1 or 3 planes, "
+            f"but found {img.shape[-3]}"
+        )
+
+    if img.shape[-3] == 3:
+        return img
+
+    # it is a grayscale image - repeat the image 3 times; the result of
+    # ``repeat`` already lives on the same device as ``img``
+    repetitions = img.dim() * [1]
+    repetitions[-3] = 3
+    return img.repeat(*repetitions)
+
+
+def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
+    """Converts an image in RGB to grayscale.
+
+    If the image is already in grayscale format, then this is a NOOP - the same
+    tensor is returned (no cloning). If the image is in RGB format, then
+    compresses the color planes into grayscale following this equation:
+
+    .. math::
+
+        grayscale = (0.2989 * r + 0.587 * g + 0.114 * b)
+
+    A new tensor is returned in this case.
+
+
+    Parameters
+    ----------
+
+    img
+        The tensor to be transformed. Expected to be in the form: ``[...,
+        [1,3], H, W]`` (i.e. arbitrary number of leading dimensions).
+
+    Returns
+    -------
+
+    img
+        transformed tensor where the 3rd dimension from the last is 1.
+
+    Raises
+    ------
+
+    TypeError
+        If ``img`` has fewer than 3 dimensions, or if its 3rd dimension from
+        the last is neither 1 nor 3.
+    """
+    if img.ndim < 3:
+        raise TypeError(
+            "Input image tensor should have at least 3 dimensions, "
+            f"but found {img.ndim}"
+        )
+
+    if img.shape[-3] not in (1, 3):
+        raise TypeError(
+            "Input image tensor should have 1 or 3 planes, "
+            f"but found {img.shape[-3]}"
+        )
+
+    if img.shape[-3] == 1:
+        return img
+
+    # it is an RGB image - use torchvision implementation
+    return torchvision.transforms.functional.rgb_to_grayscale(img)
+
+
+class RGB(torch.nn.Module):
+    """Converts an image in grayscale to RGB.
+
+    If the image is already in RGB format, then this is a NOOP - the same
+    tensor is returned (no cloning). If the image is in grayscale format
+    (number of bands = 1), then triplicate that band 3 times (a new copy is
+    returned in this case).
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, img: torch.Tensor) -> torch.Tensor:
+        return grayscale_to_rgb(img)
+
+
+class Grayscale(torch.nn.Module):
+    """Converts an image in RGB to grayscale.
+
+    If the image is already in grayscale format, then this is a NOOP - the same
+    tensor is returned (no cloning). If the image is in RGB format, then
+    compresses the color planes into grayscale following this equation:
+
+    .. math::
+
+        grayscale = (0.2989 * r + 0.587 * g + 0.114 * b)
+
+    A new tensor is returned in this case.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, img: torch.Tensor) -> torch.Tensor:
+        return rgb_to_grayscale(img)