diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index 7866f36df6fb9fe991e13ae6012c71df1c917777..f809879cac4804da9614cac846e83da8d646ccdb 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -73,6 +73,12 @@ class Alexnet(pl.LightningModule): self.name = "alexnet" + self.model_transforms = [ + torchvision.transforms.ToPILImage(), + torchvision.transforms.Lambda(lambda x: x.convert("RGB")), + torchvision.transforms.ToTensor(), + ] + self._train_loss = train_loss self._validation_loss = ( validation_loss if validation_loss is not None else train_loss diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index d1f4d03ccb21c4d5bf6d6df60ab717a5a34078fd..72637b6fca1a6357de2573479039cf468ec41efa 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -71,6 +71,12 @@ class Densenet(pl.LightningModule): self.name = "densenet-121" + self.model_transforms = [ + torchvision.transforms.ToPILImage(), + torchvision.transforms.Lambda(lambda x: x.convert("RGB")), + torchvision.transforms.ToTensor(), + ] + self._train_loss = train_loss self._validation_loss = ( validation_loss if validation_loss is not None else train_loss diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 4e0e281b7429b782faaffb09d2f22fcdc49f61c0..3202b5dbf045d683f4b0769464744f7b4227c236 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -72,6 +72,8 @@ class Pasa(pl.LightningModule): self.name = "pasa" + self.model_transforms = [] + self._train_loss = train_loss self._validation_loss = ( validation_loss if validation_loss is not None else train_loss diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py index f73baabab5d2f09f5d2fd6d802c429ec63a2cc6f..3de96bdd77aea7390df6f87d14bd835861049d43 100644 --- a/src/ptbench/scripts/predict.py +++ b/src/ptbench/scripts/predict.py @@ -100,6 +100,8 @@ def predict( from ..utils.plot import relevance_analysis_plot datamodule.set_chunk_size(batch_size, 1) + datamodule.model_transforms = model.model_transforms + logger.info(f"Loading checkpoint from {weight}") model = model.load_from_checkpoint(weight, strict=False) diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index d026e92236a21c7663fe884bd49b73079b7b55d2..e331108db6ec48aaf42fd3593bc0b0abb379849a 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -248,6 +248,7 @@ def train( datamodule.drop_incomplete_batch = drop_incomplete_batch datamodule.cache_samples = cache_samples datamodule.parallel = parallel + datamodule.model_transforms = model.model_transforms datamodule.prepare_data() datamodule.setup(stage="fit")