Skip to content
Snippets Groups Projects
Commit f6b2e274 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Added resize transforms in models

parent 63f51a74
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -74,6 +74,7 @@ class Alexnet(pl.LightningModule): ...@@ -74,6 +74,7 @@ class Alexnet(pl.LightningModule):
self.name = "alexnet" self.name = "alexnet"
self.model_transforms = [ self.model_transforms = [
torchvision.transforms.Resize(512),
torchvision.transforms.ToPILImage(), torchvision.transforms.ToPILImage(),
torchvision.transforms.Lambda(lambda x: x.convert("RGB")), torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
torchvision.transforms.ToTensor(), torchvision.transforms.ToTensor(),
......
...@@ -72,6 +72,7 @@ class Densenet(pl.LightningModule): ...@@ -72,6 +72,7 @@ class Densenet(pl.LightningModule):
self.name = "densenet-121" self.name = "densenet-121"
self.model_transforms = [ self.model_transforms = [
torchvision.transforms.Resize(512),
torchvision.transforms.ToPILImage(), torchvision.transforms.ToPILImage(),
torchvision.transforms.Lambda(lambda x: x.convert("RGB")), torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
torchvision.transforms.ToTensor(), torchvision.transforms.ToTensor(),
......
...@@ -72,7 +72,9 @@ class Pasa(pl.LightningModule): ...@@ -72,7 +72,9 @@ class Pasa(pl.LightningModule):
self.name = "pasa" self.name = "pasa"
self.model_transforms = [] self.model_transforms = [
torchvision.transforms.Resize(512),
]
self._train_loss = train_loss self._train_loss = train_loss
self._validation_loss = ( self._validation_loss = (
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment