From 6290c1bd893c1832ecb2c0e50950b2ece27fb4e6 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 11 Apr 2023 08:18:18 +0200 Subject: [PATCH] Moved densenet to lightning --- src/ptbench/configs/models/densenet.py | 19 +++--- src/ptbench/models/densenet.py | 82 +++++++++++++++++++------- 2 files changed, 73 insertions(+), 28 deletions(-) diff --git a/src/ptbench/configs/models/densenet.py b/src/ptbench/configs/models/densenet.py index 2017786a..67594908 100644 --- a/src/ptbench/configs/models/densenet.py +++ b/src/ptbench/configs/models/densenet.py @@ -4,19 +4,22 @@ """DenseNet.""" +from torch import empty from torch.nn import BCEWithLogitsLoss -from torch.optim import Adam -from ...models.densenet import build_densenet +from ...models.densenet import Densenet # config -lr = 0.0001 - -# model -model = build_densenet(pretrained=False) +optimizer_configs = {"lr": 0.0001} # optimizer -optimizer = Adam(model.parameters(), lr=lr) +optimizer = "Adam" # criterion -criterion = BCEWithLogitsLoss() +criterion = BCEWithLogitsLoss(pos_weight=empty(1)) +criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) + +# model +model = Densenet( + criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False +) diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index 7a98acac..33476d42 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -2,23 +2,40 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from collections import OrderedDict - +import pytorch_lightning as pl +import torch import torch.nn as nn import torchvision.models as models from .normalizer import TorchVisionNormalizer -class Densenet(nn.Module): +class Densenet(pl.LightningModule): """Densenet module. Note: only usable with a normalized dataset """ - def __init__(self, pretrained=False): + def __init__( + self, + criterion, + criterion_valid, + optimizer, + optimizer_params, + pretrained=False, + nb_channels=3, + ): super().__init__() + self.save_hyperparameters() + + self.name = "Densenet" + + self.criterion = criterion + self.criterion_valid = criterion_valid + + self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels) + # Load pretrained model weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT self.model_ft = models.densenet121(weights=weights) @@ -43,23 +60,48 @@ class Densenet(nn.Module): tensor : :py:class:`torch.Tensor` """ - return self.model_ft(x) + x = self.normalizer(x) -def build_densenet(pretrained=False, nb_channels=3): - """Build Densenet CNN. + x = self.model_ft(x) - Returns - ------- + return x - module : :py:class:`torch.nn.Module` - """ - model = Densenet(pretrained=pretrained) - model = [ - ("normalizer", TorchVisionNormalizer(nb_channels=nb_channels)), - ("model", model), - ] - model = nn.Sequential(OrderedDict(model)) - - model.name = "Densenet" - return model + def training_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # Forward pass on the network + outputs = self(images) + + training_loss = self.criterion(outputs, labels.double()) + + return {"loss": training_loss} + + def validation_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # data forwarding on the existing network + outputs = self(images) + validation_loss = self.criterion_valid(outputs, labels.double()) + + return {"validation_loss": validation_loss} + + def configure_optimizers(self): + # Dynamically instantiates the optimizer given the configs + optimizer = getattr(torch.optim, self.hparams.optimizer)( + self.parameters(), **self.hparams.optimizer_params + ) + + return optimizer -- GitLab