From 63f51a74df653d539e2c5c7f9550f4fbd84a2548 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 18 Jul 2023 10:36:41 +0200
Subject: [PATCH] Added model_transforms in models

---
 src/ptbench/models/alexnet.py  | 6 ++++++
 src/ptbench/models/densenet.py | 6 ++++++
 src/ptbench/models/pasa.py     | 2 ++
 src/ptbench/scripts/predict.py | 2 ++
 src/ptbench/scripts/train.py   | 1 +
 5 files changed, 17 insertions(+)

diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index 7866f36d..f809879c 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -73,6 +73,12 @@ class Alexnet(pl.LightningModule):
 
         self.name = "alexnet"
 
+        self.model_transforms = [
+            torchvision.transforms.ToPILImage(),
+            torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
+            torchvision.transforms.ToTensor(),
+        ]
+
         self._train_loss = train_loss
         self._validation_loss = (
             validation_loss if validation_loss is not None else train_loss
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index d1f4d03c..72637b6f 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -71,6 +71,12 @@ class Densenet(pl.LightningModule):
 
         self.name = "densenet-121"
 
+        self.model_transforms = [
+            torchvision.transforms.ToPILImage(),
+            torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
+            torchvision.transforms.ToTensor(),
+        ]
+
         self._train_loss = train_loss
         self._validation_loss = (
             validation_loss if validation_loss is not None else train_loss
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 4e0e281b..3202b5db 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -72,6 +72,8 @@ class Pasa(pl.LightningModule):
 
         self.name = "pasa"
 
+        self.model_transforms = []
+
         self._train_loss = train_loss
         self._validation_loss = (
             validation_loss if validation_loss is not None else train_loss
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index f73baaba..3de96bdd 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -100,6 +100,8 @@ def predict(
     from ..utils.plot import relevance_analysis_plot
 
     datamodule.set_chunk_size(batch_size, 1)
+    datamodule.model_transforms = model.model_transforms
+
 
     logger.info(f"Loading checkpoint from {weight}")
     model = model.load_from_checkpoint(weight, strict=False)
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index d026e922..e331108d 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -248,6 +248,7 @@ def train(
     datamodule.drop_incomplete_batch = drop_incomplete_batch
     datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
+    datamodule.model_transforms = model.model_transforms
 
     datamodule.prepare_data()
     datamodule.setup(stage="fit")
-- 
GitLab