From e7446664feea65c2fe843d6a57d6b5a6f6313ba2 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 11 Apr 2023 14:35:37 +0200 Subject: [PATCH] Moved logistic_regression model to lightning --- .../configs/models/logistic_regression.py | 21 ++--- src/ptbench/models/logistic_regression.py | 79 ++++++++++++++++--- 2 files changed, 79 insertions(+), 21 deletions(-) diff --git a/src/ptbench/configs/models/logistic_regression.py b/src/ptbench/configs/models/logistic_regression.py index b93935b4..145dddd7 100644 --- a/src/ptbench/configs/models/logistic_regression.py +++ b/src/ptbench/configs/models/logistic_regression.py @@ -7,20 +7,23 @@ Simple feedforward network taking radiological signs in output and predicting tuberculosis presence in output. """ - +from torch import empty from torch.nn import BCEWithLogitsLoss -from torch.optim import Adam -from ...models.logistic_regression import build_logistic_regression +from ...models.logistic_regression import LogisticRegression # config -lr = 1e-2 - -# model -model = build_logistic_regression(14) +optimizer_configs = {"lr": 1e-2} +input_size = 14 # optimizer -optimizer = Adam(model.parameters(), lr=lr) +optimizer = "Adam" # criterion -criterion = BCEWithLogitsLoss() +criterion = BCEWithLogitsLoss(pos_weight=empty(1)) +criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) + +# model +model = LogisticRegression( + criterion, criterion_valid, optimizer, optimizer_configs, input_size +) diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py index 7e7818c7..684155b4 100644 --- a/src/ptbench/models/logistic_regression.py +++ b/src/ptbench/models/logistic_regression.py @@ -2,16 +2,32 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import pytorch_lightning as pl import torch import torch.nn as nn -class LogisticRegression(nn.Module): +class LogisticRegression(pl.LightningModule): """Radiological signs to Tuberculosis module.""" - def __init__(self, input_size): + def __init__( + self, + criterion, + criterion_valid, + optimizer, + optimizer_configs, + input_size, + ): super().__init__() - self.linear = torch.nn.Linear(input_size, 1) + + self.save_hyperparameters() + + self.criterion = criterion + self.criterion_valid = criterion_valid + + self.name = "logistic_regression" + + self.linear = nn.Linear(input_size, 1) def forward(self, x): """ @@ -32,15 +48,54 @@ class LogisticRegression(nn.Module): return output + def training_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # Forward pass on the network + outputs = self(images) + + training_loss = self.criterion(outputs, labels.double()) + + return {"loss": training_loss} + + def validation_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # data forwarding on the existing network + outputs = self(images) + validation_loss = self.criterion_valid(outputs, labels.double()) + + return {"validation_loss": validation_loss} + + def predict_step(self, batch, batch_idx, grad_cams=False): + names = batch[0] + images = batch[1] + + outputs = self(images) + probabilities = torch.sigmoid(outputs) + + # necessary check for HED architecture that uses several outputs + # for loss calculation instead of just the last concatfuse block + if isinstance(outputs, list): + outputs = outputs[-1] -def build_logistic_regression(input_size): - """Build logistic regression module. + return names[0], torch.flatten(probabilities), torch.flatten(batch[2]) - Returns - ------- + def configure_optimizers(self): + optimizer = getattr(torch.optim, self.hparams.optimizer)( + self.parameters(), **self.hparams.optimizer_configs + ) - module : :py:class:`torch.nn.Module` - """ - model = LogisticRegression(input_size) - model.name = "logistic_regression" - return model + return optimizer -- GitLab