diff --git a/src/ptbench/configs/models_datasets/densenet_rs.py b/src/ptbench/configs/models_datasets/densenet_rs.py index 57fc7b78674779d50fe16002dc154b0dbe13e33a..714404bf173c8868e4c67e49f257a1b221e540d7 100644 --- a/src/ptbench/configs/models_datasets/densenet_rs.py +++ b/src/ptbench/configs/models_datasets/densenet_rs.py @@ -7,10 +7,10 @@ A Densenet121 model for radiological extraction """ +from torch import empty from torch.nn import BCEWithLogitsLoss -from torch.optim import Adam -from ...models.densenet_rs import build_densenetrs +from ...models.densenet_rs import DensenetRS # Import the default protocol if none is available if "dataset" not in locals(): @@ -19,16 +19,14 @@ if "dataset" not in locals(): dataset = default.dataset # config -lr = 1e-4 - -# model -model = build_densenetrs() +optimizer_configs = {"lr": 1e-4} # optimizer -optimizer = Adam( - filter(lambda p: p.requires_grad, model.model.model_ft.parameters()), lr=lr -) +optimizer = "Adam" # criterion -criterion = BCEWithLogitsLoss() -criterion_valid = BCEWithLogitsLoss() +criterion = BCEWithLogitsLoss(pos_weight=empty(1)) +criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) + +# model +model = DensenetRS(criterion, criterion_valid, optimizer, optimizer_configs) diff --git a/src/ptbench/models/densenet_rs.py b/src/ptbench/models/densenet_rs.py index c4448fbca8d97a60ccc041cc847bd79a0fd58056..997516a02bcdb5f2b7fdbe04e10fd48077d51092 100644 --- a/src/ptbench/models/densenet_rs.py +++ b/src/ptbench/models/densenet_rs.py @@ -2,20 +2,35 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from collections import OrderedDict - +import pytorch_lightning as pl +import torch import torch.nn as nn import torchvision.models as models from .normalizer import TorchVisionNormalizer -class DensenetRS(nn.Module): +class DensenetRS(pl.LightningModule): """Densenet121 module for radiological extraction.""" - def __init__(self): + def __init__( + self, + criterion, + criterion_valid, + optimizer, + optimizer_configs, + ): super().__init__() + self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) + + self.name = "DensenetRS" + + self.criterion = criterion + self.criterion_valid = criterion_valid + + self.normalizer = TorchVisionNormalizer() + # Load pretrained model self.model_ft = models.densenet121( weights=models.DenseNet121_Weights.DEFAULT @@ -40,20 +55,61 @@ class DensenetRS(nn.Module): tensor : :py:class:`torch.Tensor` """ - return self.model_ft(x) + x = self.normalizer(x) + x = self.model_ft(x) + return x + + def training_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # Forward pass on the network + outputs = self(images) + + training_loss = self.criterion(outputs, labels.double()) -def build_densenetrs(): - """Build DensenetRS CNN. + return {"loss": training_loss} - Returns - ------- + def validation_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] - module : :py:class:`torch.nn.Module` - """ - model = DensenetRS() - model = [("normalizer", TorchVisionNormalizer()), ("model", model)] - model = nn.Sequential(OrderedDict(model)) + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # data forwarding on the existing network + outputs = self(images) + validation_loss = self.criterion_valid(outputs, labels.double()) + + return {"validation_loss": validation_loss} + + def predict_step(self, batch, batch_idx, grad_cams=False): + names = batch[0] + images = batch[1] + + outputs = self(images) + probabilities = torch.sigmoid(outputs) + + # necessary check for HED architecture that uses several outputs + # for loss calculation instead of just the last concatfuse block + if isinstance(outputs, list): + outputs = outputs[-1] + + return names[0], torch.flatten(probabilities), torch.flatten(batch[2]) + + def configure_optimizers(self): + # Dynamically instantiates the optimizer given the configs + optimizer = getattr(torch.optim, self.hparams.optimizer)( + filter(lambda p: p.requires_grad, self.model_ft.parameters()), + **self.hparams.optimizer_configs, + ) - model.name = "DensenetRS" - return model + return optimizer