From f6b2e274ec25118bc8016de345e4748b9ed549e8 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 18 Jul 2023 12:31:45 +0200 Subject: [PATCH] Added resize transforms in models --- src/ptbench/models/alexnet.py | 1 + src/ptbench/models/densenet.py | 1 + src/ptbench/models/pasa.py | 4 +++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index f809879c..0b19b3d7 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -74,6 +74,7 @@ class Alexnet(pl.LightningModule): self.name = "alexnet" self.model_transforms = [ + torchvision.transforms.Resize(512), torchvision.transforms.ToPILImage(), torchvision.transforms.Lambda(lambda x: x.convert("RGB")), torchvision.transforms.ToTensor(), diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index 72637b6f..f6eb2cb6 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -72,6 +72,7 @@ class Densenet(pl.LightningModule): self.name = "densenet-121" self.model_transforms = [ + torchvision.transforms.Resize(512), torchvision.transforms.ToPILImage(), torchvision.transforms.Lambda(lambda x: x.convert("RGB")), torchvision.transforms.ToTensor(), diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 3202b5db..e2cb9b05 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -72,7 +72,9 @@ class Pasa(pl.LightningModule): self.name = "pasa" - self.model_transforms = [] + self.model_transforms = [ + torchvision.transforms.Resize(512), + ] self._train_loss = train_loss self._validation_loss = ( -- GitLab