# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision.models as models

from .normalizer import TorchVisionNormalizer


class Densenet(pl.LightningModule):
    """Densenet module.

    DenseNet-121 backbone (optionally ImageNet-pretrained) with its
    classifier replaced by a small head producing a single output logit.

    Note: only usable with a normalized dataset

    Parameters
    ----------
    criterion : callable
        Loss used in :py:meth:`training_step`, called as
        ``criterion(outputs, labels)``.
    criterion_valid : callable
        Loss used in :py:meth:`validation_step`, same call convention.
    optimizer : str
        Name of an optimizer class in :py:mod:`torch.optim` (e.g. ``"Adam"``).
    optimizer_configs : dict
        Keyword arguments forwarded to the optimizer constructor.
    pretrained : bool
        If ``True``, load ``DenseNet121_Weights.DEFAULT`` into the backbone.
    nb_channels : int
        Number of input channels, forwarded to the input normalizer.
    """

    def __init__(
        self,
        criterion,
        criterion_valid,
        optimizer,
        optimizer_configs,
        pretrained=False,
        nb_channels=3,
    ):
        super().__init__()

        # Stores every constructor argument under ``self.hparams``;
        # configure_optimizers() reads optimizer settings back from there.
        self.save_hyperparameters()

        self.name = "Densenet"

        self.criterion = criterion
        self.criterion_valid = criterion_valid

        self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels)

        # Load pretrained model
        weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT
        self.model_ft = models.densenet121(weights=weights)

        # Adapt output features: 1024 backbone features -> 1 logit.
        # NOTE(review): there is no non-linearity between the two Linear
        # layers, so they compose into a single affine map. Left unchanged
        # to stay compatible with existing checkpoints.
        self.model_ft.classifier = nn.Sequential(
            nn.Linear(1024, 256), nn.Linear(256, 1)
        )

    def forward(self, x):
        """Normalize the input and run it through the network.

        Parameters
        ----------
        x : :py:class:`torch.Tensor`
            Batch of input images (presumably ``(N, C, H, W)`` with
            ``C == nb_channels`` -- TODO confirm against the data pipeline).

        Returns
        -------
        tensor : :py:class:`torch.Tensor`
            Raw (pre-sigmoid) logits produced by the classifier head.
        """
        x = self.normalizer(x)
        x = self.model_ft(x)

        return x

    @staticmethod
    def _reshape_labels(labels):
        """Turn 1-D labels of shape ``(N,)`` into ``(N, 1)``; leave others as-is."""
        # Increase label dimension if too low
        # Allows single and multiclass usage
        if labels.ndim == 1:
            labels = torch.reshape(labels, (labels.shape[0], 1))
        return labels

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch.

        ``batch`` is indexed as ``batch[1]`` (images) and ``batch[2]``
        (labels).
        """
        images = batch[1]
        labels = self._reshape_labels(batch[2])

        # Forward pass on the network
        outputs = self(images)

        training_loss = self.criterion(outputs, labels.double())

        return {"loss": training_loss}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch (same layout as training)."""
        images = batch[1]
        labels = self._reshape_labels(batch[2])

        # data forwarding on the existing network
        outputs = self(images)
        validation_loss = self.criterion_valid(outputs, labels.double())

        return {"validation_loss": validation_loss}

    def predict_step(self, batch, batch_idx, grad_cams=False):
        """Run inference on one batch.

        ``batch`` is indexed as ``batch[0]`` (sample names), ``batch[1]``
        (images) and ``batch[2]`` (labels). ``grad_cams`` is accepted for
        interface compatibility but is currently unused.

        Returns
        -------
        tuple
            ``(names[0], flattened sigmoid probabilities, flattened labels)``.
        """
        names = batch[0]
        images = batch[1]

        outputs = self(images)

        # necessary check for HED architecture that uses several outputs
        # for loss calculation instead of just the last concatfuse block.
        # Must happen *before* the sigmoid, which cannot take a list.
        if isinstance(outputs, list):
            outputs = outputs[-1]

        probabilities = torch.sigmoid(outputs)

        # NOTE(review): only the first sample name is returned while the
        # probabilities cover the whole batch -- presumably prediction runs
        # with batch size 1; confirm with the caller.
        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])

    def configure_optimizers(self):
        """Instantiate the optimizer named in the hyper-parameters.

        Looks up ``self.hparams.optimizer`` in :py:mod:`torch.optim` and
        passes ``self.hparams.optimizer_configs`` as keyword arguments.
        """
        # The constructor argument is ``optimizer_configs`` and
        # save_hyperparameters() stores it under that exact name; the
        # previous ``optimizer_params`` key did not exist.
        optimizer = getattr(torch.optim, self.hparams.optimizer)(
            self.parameters(), **self.hparams.optimizer_configs
        )

        return optimizer