From 63f51a74df653d539e2c5c7f9550f4fbd84a2548 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 18 Jul 2023 10:36:41 +0200 Subject: [PATCH] Added model_transforms in models --- src/ptbench/models/alexnet.py | 6 ++++++ src/ptbench/models/densenet.py | 6 ++++++ src/ptbench/models/pasa.py | 2 ++ src/ptbench/scripts/predict.py | 2 ++ src/ptbench/scripts/train.py | 1 + 5 files changed, 17 insertions(+) diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index 7866f36d..f809879c 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -73,6 +73,12 @@ class Alexnet(pl.LightningModule): self.name = "alexnet" + self.model_transforms = [ + torchvision.transforms.ToPILImage(), + torchvision.transforms.Lambda(lambda x: x.convert("RGB")), + torchvision.transforms.ToTensor(), + ] + self._train_loss = train_loss self._validation_loss = ( validation_loss if validation_loss is not None else train_loss diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index d1f4d03c..72637b6f 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -71,6 +71,12 @@ class Densenet(pl.LightningModule): self.name = "densenet-121" + self.model_transforms = [ + torchvision.transforms.ToPILImage(), + torchvision.transforms.Lambda(lambda x: x.convert("RGB")), + torchvision.transforms.ToTensor(), + ] + self._train_loss = train_loss self._validation_loss = ( validation_loss if validation_loss is not None else train_loss diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 4e0e281b..3202b5db 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -72,6 +72,8 @@ class Pasa(pl.LightningModule): self.name = "pasa" + self.model_transforms = [] + self._train_loss = train_loss self._validation_loss = ( validation_loss if validation_loss is not None else train_loss diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py index f73baaba..3de96bdd 100644 --- a/src/ptbench/scripts/predict.py +++ b/src/ptbench/scripts/predict.py @@ -100,6 +100,8 @@ def predict( from ..utils.plot import relevance_analysis_plot datamodule.set_chunk_size(batch_size, 1) + datamodule.model_transforms = model.model_transforms + logger.info(f"Loading checkpoint from {weight}") model = model.load_from_checkpoint(weight, strict=False) diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index d026e922..e331108d 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -248,6 +248,7 @@ def train( datamodule.drop_incomplete_batch = drop_incomplete_batch datamodule.cache_samples = cache_samples datamodule.parallel = parallel + datamodule.model_transforms = model.model_transforms datamodule.prepare_data() datamodule.setup(stage="fit") -- GitLab