Skip to content
Snippets Groups Projects

Reviewed DataModule design+docs+types

Merged André Anjos requested to merge add-datamodule-andre into add-datamodule
3 files
+ 46
62
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -7,8 +7,6 @@ import torch
import torch.nn as nn
import torchvision.models as models
from .normalizer import TorchVisionNormalizer
class Densenet(pl.LightningModule):
"""Densenet module.
@@ -31,7 +29,7 @@ class Densenet(pl.LightningModule):
self.name = "Densenet"
self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels)
self.normalizer = None
# Load pretrained model
weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT
@@ -55,16 +53,13 @@ class Densenet(pl.LightningModule):
imagenet weights, during contruction).
"""
if self.pretrained:
from .normalizer import TorchVisionNormalizer
from .normalizer import make_imagenet_normalizer
self.normalizer = TorchVisionNormalizer(
torch.Tensor([0.485, 0.456, 0.406]),
torch.Tensor([0.229, 0.224, 0.225]),
)
self.normalizer = make_imagenet_normalizer()
else:
from .normalizer import get_znorm_normalizer
from .normalizer import make_z_normalizer
self.normalizer = get_znorm_normalizer(dataloader)
self.normalizer = make_z_normalizer(dataloader)
def training_step(self, batch, batch_idx):
images = batch[1]
Loading