From f6b2e274ec25118bc8016de345e4748b9ed549e8 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 18 Jul 2023 12:31:45 +0200
Subject: [PATCH] Added resize transforms in models

---
 src/ptbench/models/alexnet.py  | 1 +
 src/ptbench/models/densenet.py | 1 +
 src/ptbench/models/pasa.py     | 4 +++-
 3 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index f809879c..0b19b3d7 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -74,6 +74,7 @@ class Alexnet(pl.LightningModule):
         self.name = "alexnet"
 
         self.model_transforms = [
+            torchvision.transforms.Resize(512),
             torchvision.transforms.ToPILImage(),
             torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
             torchvision.transforms.ToTensor(),
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 72637b6f..f6eb2cb6 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -72,6 +72,7 @@ class Densenet(pl.LightningModule):
         self.name = "densenet-121"
 
         self.model_transforms = [
+            torchvision.transforms.Resize(512),
             torchvision.transforms.ToPILImage(),
             torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
             torchvision.transforms.ToTensor(),
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 3202b5db..e2cb9b05 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -72,7 +72,9 @@ class Pasa(pl.LightningModule):
 
         self.name = "pasa"
 
-        self.model_transforms = []
+        self.model_transforms = [
+            torchvision.transforms.Resize(512),
+        ]
 
         self._train_loss = train_loss
         self._validation_loss = (
-- 
GitLab