From 6bb7ac555756a8f36bc2db88b679fa31c6423686 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 11 Apr 2023 14:32:09 +0200 Subject: [PATCH] Moved signs_to_tb model to lightning --- src/ptbench/configs/models/signs_to_tb.py | 19 +++--- src/ptbench/models/signs_to_tb.py | 80 +++++++++++++++++++---- 2 files changed, 79 insertions(+), 20 deletions(-) diff --git a/src/ptbench/configs/models/signs_to_tb.py b/src/ptbench/configs/models/signs_to_tb.py index 3bd552da..1ce89b12 100644 --- a/src/ptbench/configs/models/signs_to_tb.py +++ b/src/ptbench/configs/models/signs_to_tb.py @@ -8,19 +8,22 @@ Simple feedforward network taking radiological signs in output and predicting tuberculosis presence in output. """ +from torch import empty from torch.nn import BCEWithLogitsLoss -from torch.optim import Adam -from ...models.signs_to_tb import build_signs_to_tb +from ...models.signs_to_tb import SignsToTB # config -lr = 1e-2 - -# model -model = build_signs_to_tb(14, 10) +optimizer_configs = {"lr": 1e-2} # optimizer -optimizer = Adam(model.parameters(), lr=lr) +optimizer = "Adam" # criterion -criterion = BCEWithLogitsLoss() +criterion = BCEWithLogitsLoss(pos_weight=empty(1)) +criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) + +# model +model = SignsToTB( + criterion, criterion_valid, optimizer, optimizer_configs, 14, 10 +) diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py index f3b3d5ea..653b590a 100644 --- a/src/ptbench/models/signs_to_tb.py +++ b/src/ptbench/models/signs_to_tb.py @@ -2,15 +2,31 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import pytorch_lightning as pl import torch -import torch.nn as nn -class SignsToTB(nn.Module): +class SignsToTB(pl.LightningModule): """Radiological signs to Tuberculosis module.""" - def __init__(self, input_size, hidden_size): + def __init__( + self, + criterion, + criterion_valid, + optimizer, + optimizer_configs, + input_size, + hidden_size, + ): super().__init__() + + self.save_hyperparameters() + + self.name = "signs_to_tb" + + self.criterion = criterion + self.criterion_valid = criterion_valid + self.input_size = input_size self.hidden_size = hidden_size self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size) @@ -39,15 +55,55 @@ class SignsToTB(nn.Module): return output + def training_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # Forward pass on the network + outputs = self(images) + + training_loss = self.criterion(outputs, labels.double()) + + return {"loss": training_loss} + + def validation_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # data forwarding on the existing network + outputs = self(images) + validation_loss = self.criterion_valid(outputs, labels.double()) + + return {"validation_loss": validation_loss} + + def predict_step(self, batch, batch_idx, grad_cams=False): + names = batch[0] + images = batch[1] + + outputs = self(images) + probabilities = torch.sigmoid(outputs) + + # necessary check for HED architecture that uses several outputs + # for loss calculation instead of just the last concatfuse block + if isinstance(outputs, list): + outputs = outputs[-1] -def build_signs_to_tb(input_size, hidden_size): - """Build SignsToTB shallow model. + return names[0], torch.flatten(probabilities), torch.flatten(batch[2]) - Returns - ------- + def configure_optimizers(self): + # Dynamically instantiates the optimizer given the configs + optimizer = getattr(torch.optim, self.hparams.optimizer)( + self.parameters(), **self.hparams.optimizer_configs + ) - module : :py:class:`torch.nn.Module` - """ - model = SignsToTB(input_size, hidden_size) - model.name = "signs_to_tb" - return model + return optimizer -- GitLab