Skip to content
Snippets Groups Projects
Commit 63f51a74 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Added model_transforms in models

parent 826d392e
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -73,6 +73,12 @@ class Alexnet(pl.LightningModule): ...@@ -73,6 +73,12 @@ class Alexnet(pl.LightningModule):
self.name = "alexnet" self.name = "alexnet"
self.model_transforms = [
torchvision.transforms.ToPILImage(),
torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
torchvision.transforms.ToTensor(),
]
self._train_loss = train_loss self._train_loss = train_loss
self._validation_loss = ( self._validation_loss = (
validation_loss if validation_loss is not None else train_loss validation_loss if validation_loss is not None else train_loss
......
...@@ -71,6 +71,12 @@ class Densenet(pl.LightningModule): ...@@ -71,6 +71,12 @@ class Densenet(pl.LightningModule):
self.name = "densenet-121" self.name = "densenet-121"
self.model_transforms = [
torchvision.transforms.ToPILImage(),
torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
torchvision.transforms.ToTensor(),
]
self._train_loss = train_loss self._train_loss = train_loss
self._validation_loss = ( self._validation_loss = (
validation_loss if validation_loss is not None else train_loss validation_loss if validation_loss is not None else train_loss
......
...@@ -72,6 +72,8 @@ class Pasa(pl.LightningModule): ...@@ -72,6 +72,8 @@ class Pasa(pl.LightningModule):
self.name = "pasa" self.name = "pasa"
self.model_transforms = []
self._train_loss = train_loss self._train_loss = train_loss
self._validation_loss = ( self._validation_loss = (
validation_loss if validation_loss is not None else train_loss validation_loss if validation_loss is not None else train_loss
......
...@@ -100,6 +100,8 @@ def predict( ...@@ -100,6 +100,8 @@ def predict(
from ..utils.plot import relevance_analysis_plot from ..utils.plot import relevance_analysis_plot
datamodule.set_chunk_size(batch_size, 1) datamodule.set_chunk_size(batch_size, 1)
datamodule.model_transforms = model.model_transforms
logger.info(f"Loading checkpoint from {weight}") logger.info(f"Loading checkpoint from {weight}")
model = model.load_from_checkpoint(weight, strict=False) model = model.load_from_checkpoint(weight, strict=False)
......
...@@ -248,6 +248,7 @@ def train( ...@@ -248,6 +248,7 @@ def train(
datamodule.drop_incomplete_batch = drop_incomplete_batch datamodule.drop_incomplete_batch = drop_incomplete_batch
datamodule.cache_samples = cache_samples datamodule.cache_samples = cache_samples
datamodule.parallel = parallel datamodule.parallel = parallel
datamodule.model_transforms = model.model_transforms
datamodule.prepare_data() datamodule.prepare_data()
datamodule.setup(stage="fit") datamodule.setup(stage="fit")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment