diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index df224bc5f2eb973003b87a8170b0b53fd70b5aad..5daed2a31a2834997d7d90df5a444a68da4d831f 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -14,6 +14,7 @@ import torchvision.models as models
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .transforms import RGB
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -78,10 +79,9 @@ class Alexnet(pl.LightningModule):
         self.name = "alexnet"
 
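+        # image is probably large, resize first to get memory usage down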
         self.model_transforms = [
-            torchvision.transforms.Resize(512),
-            torchvision.transforms.ToPILImage(),
-            torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
-            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Resize(512, antialias=True),
+            RGB(),
         ]
 
         self._train_loss = train_loss
@@ -198,6 +197,12 @@ class Alexnet(pl.LightningModule):
         if labels.ndim == 1:
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
+        # debug code to inspect images by eye:
+        # from torchvision.transforms.functional import to_pil_image
+        # for k in images:
+        #    to_pil_image(k).show()
+        #    __import__("pdb").set_trace()
+
         # data forwarding on the existing network
         outputs = self(images)
 
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 436d73ba55d5829cd886b37ebee882b80c41d2d8..3e2ee50e664c9a00bb34d8ea5bee298c3fba893a 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -14,6 +14,7 @@ import torchvision.models as models
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .transforms import RGB
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -75,11 +76,10 @@ class Densenet(pl.LightningModule):
 
         self.name = "densenet-121"
 
+        # image is probably large, resize first to get memory usage down
         self.model_transforms = [
-            torchvision.transforms.Resize(512),
-            torchvision.transforms.ToPILImage(),
-            torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
-            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Resize(512, antialias=True),
+            RGB(),
         ]
 
         self._train_loss = train_loss
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 650aa7d4f60dfe5feda914deb8518415fda8876d..3f10757fe2126b5be67ef577bccd8c8bbefe3fa5 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -14,6 +14,7 @@ import torch.utils.data
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .transforms import Grayscale
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -72,9 +73,10 @@ class Pasa(pl.LightningModule):
 
         self.name = "pasa"
 
+        # image is probably large, resize first to get memory usage down
         self.model_transforms = [
-            torchvision.transforms.Grayscale(),
             torchvision.transforms.Resize(512, antialias=True),
+            Grayscale(),
         ]
 
         self._train_loss = train_loss
diff --git a/src/ptbench/models/transforms.py b/src/ptbench/models/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..585b5846151292c69c8bb3411f68e28dc5eb97e2
--- /dev/null
+++ b/src/ptbench/models/transforms.py
@@ -0,0 +1,165 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""A transform that turns grayscale images to RGB."""
+
+import torch
+import torch.nn
+import torchvision.transforms.functional
+
+
+def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor:
+    """Converts an image in grayscale to RGB.
+
+    If the image is already in RGB format, then this is a NOOP - the same
+    tensor is returned (no cloning).  If the image is in grayscale format
+    (number of bands = 1), then triplicate that band 3 times (a new copy is
+    returned in this case).
+
+
+    Parameters
+    ----------
+
+    img
+        The tensor to be transformed.  Expected to be in the form: ``[...,
+        [1,3], H, W]`` (i.e. arbitrary number of leading dimensions).
+
+    Returns
+    -------
+
+    img
+        The transformed tensor, where the third-from-last dimension has size 3.
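+
+    Examples
+    --------
+
+    A minimal sketch (shapes are illustrative):
+
+    >>> import torch
+    >>> grayscale_to_rgb(torch.rand(1, 16, 16)).shape
+    torch.Size([3, 16, 16])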
+    """
+    if img.ndim < 3:
+        raise TypeError(
+            f"Input image tensor should have at least 3 dimensions,"
+            f"but found {img.ndim}"
+        )
+
+    if img.shape[-3] not in (1, 3):
+        raise TypeError(
+            f"Input image tensor should have 1 or 3 planes,"
+            f"but found {img.shape[-3]}"
+        )
+
+    if img.shape[-3] == 3:
+        return img
+
+    # it is a grayscale image - replicate the band 3 times along the channels
+    repetitions = img.ndim * [1]
+    repetitions[-3] = 3
+    return img.repeat(*repetitions)
+
+
+def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
+    """Converts an image in RGB to grayscale.
+
+    If the image is already in grayscale format, then this is a NOOP - the same
+    tensor is returned (no cloning).  If the image is in RGB format, then
+    compresses the color planes into grayscale following this equation:
+
+    .. math::
+
+       \text{grayscale} = 0.2989 \cdot r + 0.587 \cdot g + 0.114 \cdot b
+
+    A new tensor is returned in this case.
+
+
+    Parameters
+    ----------
+
+    img
+        The tensor to be transformed.  Expected to be in the form: ``[...,
+        [1,3], H, W]`` (i.e. arbitrary number of leading dimensions).
+
+    Returns
+    -------
+
+    img
+        The transformed tensor, where the third-from-last dimension has size 1.
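+
+    Examples
+    --------
+
+    A minimal sketch (shapes are illustrative):
+
+    >>> import torch
+    >>> rgb_to_grayscale(torch.rand(3, 16, 16)).shape
+    torch.Size([1, 16, 16])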
+    """
+    if img.ndim < 3:
+        raise TypeError(
+            f"Input image tensor should have at least 3 dimensions,"
+            f"but found {img.ndim}"
+        )
+
+    if img.shape[-3] not in (1, 3):
+        raise TypeError(
+            f"Input image tensor should have 1 or 3 planes,"
+            f"but found {img.shape[-3]}"
+        )
+
+    if img.shape[-3] == 1:
+        return img
+
+    # it is an RGB image - use torchvision implementation
+    return torchvision.transforms.functional.rgb_to_grayscale(img)
+
+
+class RGB(torch.nn.Module):
+    """Converts an image in grayscale to RGB.
+
+    If the image is already in RGB format, then this is a NOOP - the same
+    tensor is returned (no cloning).  If the image is in grayscale format
+    (number of bands = 1), then triplicate that band 3 times (a new copy is
+    returned in this case).
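+
+    Example (a minimal sketch):
+
+    >>> import torch
+    >>> RGB()(torch.rand(1, 8, 8)).shape
+    torch.Size([3, 8, 8])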
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, img: torch.Tensor) -> torch.Tensor:
+        return grayscale_to_rgb(img)
+
+
+class Grayscale(torch.nn.Module):
+    """Converts an image in RGB to grayscale.
+
+    If the image is already in grayscale format, then this is a NOOP - the same
+    tensor is returned (no cloning).  If the image is in RGB format, then
+    compresses the color planes into grayscale following this equation:
+
+    .. math::
+
+       \text{grayscale} = 0.2989 \cdot r + 0.587 \cdot g + 0.114 \cdot b
+
+    A new tensor is returned in this case.
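+
+    Example (a minimal sketch):
+
+    >>> import torch
+    >>> Grayscale()(torch.rand(3, 8, 8)).shape
+    torch.Size([1, 8, 8])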
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, img: torch.Tensor) -> torch.Tensor:
+        return rgb_to_grayscale(img)