diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index 7866f36df6fb9fe991e13ae6012c71df1c917777..f809879cac4804da9614cac846e83da8d646ccdb 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -73,6 +73,12 @@ class Alexnet(pl.LightningModule):
 
         self.name = "alexnet"
 
+        self.model_transforms = [
+            torchvision.transforms.ToPILImage(),
+            torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
+            torchvision.transforms.ToTensor(),
+        ]
+
         self._train_loss = train_loss
         self._validation_loss = (
             validation_loss if validation_loss is not None else train_loss
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index d1f4d03ccb21c4d5bf6d6df60ab717a5a34078fd..72637b6fca1a6357de2573479039cf468ec41efa 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -71,6 +71,12 @@ class Densenet(pl.LightningModule):
 
         self.name = "densenet-121"
 
+        self.model_transforms = [
+            torchvision.transforms.ToPILImage(),
+            torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
+            torchvision.transforms.ToTensor(),
+        ]
+
         self._train_loss = train_loss
         self._validation_loss = (
             validation_loss if validation_loss is not None else train_loss
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 4e0e281b7429b782faaffb09d2f22fcdc49f61c0..3202b5dbf045d683f4b0769464744f7b4227c236 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -72,6 +72,8 @@ class Pasa(pl.LightningModule):
 
         self.name = "pasa"
 
+        self.model_transforms = []
+
         self._train_loss = train_loss
         self._validation_loss = (
             validation_loss if validation_loss is not None else train_loss
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index f73baabab5d2f09f5d2fd6d802c429ec63a2cc6f..3de96bdd77aea7390df6f87d14bd835861049d43 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -100,6 +100,8 @@ def predict(
     from ..utils.plot import relevance_analysis_plot
 
     datamodule.set_chunk_size(batch_size, 1)
+    datamodule.model_transforms = model.model_transforms
+
 
     logger.info(f"Loading checkpoint from {weight}")
     model = model.load_from_checkpoint(weight, strict=False)
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index d026e92236a21c7663fe884bd49b73079b7b55d2..e331108db6ec48aaf42fd3593bc0b0abb379849a 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -248,6 +248,7 @@ def train(
     datamodule.drop_incomplete_batch = drop_incomplete_batch
     datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
+    datamodule.model_transforms = model.model_transforms
 
     datamodule.prepare_data()
     datamodule.setup(stage="fit")