logistic_regression.py
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import pytorch_lightning as pl
import torch
import torch.nn as nn


class LogisticRegression(pl.LightningModule):
    """Radiological signs to Tuberculosis module."""

    def __init__(
self,
criterion,
criterion_valid,
optimizer,
optimizer_configs,
input_size,
):
        super().__init__()

        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])

        # The loss callables are excluded from save_hyperparameters() above,
        # so they would not be available through self.hparams; keep them as
        # plain attributes instead.
        self.criterion = criterion
        self.criterion_valid = criterion_valid

        self.name = "logistic_regression"
        self.linear = nn.Linear(self.hparams.input_size, 1)

    def forward(self, x):
"""
Parameters
----------
x : list
list of tensors.
Returns
-------
tensor : :py:class:`torch.Tensor`
"""
output = self.linear(x)
return output
def training_step(self, batch, batch_idx):
        images = batch[1]
        labels = batch[2]

        # Increase label dimension if too low
        # Allows single and multiclass usage
        if labels.ndim == 1:
            labels = torch.reshape(labels, (labels.shape[0], 1))

        # Forward pass on the network
        outputs = self(images)

        # Use the attribute set in __init__, since the criterion is excluded
        # from self.hparams by save_hyperparameters(ignore=...)
        training_loss = self.criterion(outputs, labels.float())

        return {"loss": training_loss}

    def validation_step(self, batch, batch_idx):
        images = batch[1]
        labels = batch[2]

        # Increase label dimension if too low
        # Allows single and multiclass usage
        if labels.ndim == 1:
            labels = torch.reshape(labels, (labels.shape[0], 1))

        # Data forwarding on the existing network
        outputs = self(images)

        # As in training_step, use the attribute rather than self.hparams,
        # which does not contain the ignored criterion_valid
        validation_loss = self.criterion_valid(outputs, labels.float())

        return {"validation_loss": validation_loss}

    def predict_step(self, batch, batch_idx, grad_cams=False):
        names = batch[0]
        images = batch[1]

        outputs = self(images)

        # Necessary check for HED architecture that uses several outputs
        # for loss calculation instead of just the last concatfuse block.
        # Reduce to the final output *before* applying the sigmoid, since
        # torch.sigmoid() cannot operate on a list.
        if isinstance(outputs, list):
            outputs = outputs[-1]

        probabilities = torch.sigmoid(outputs)

        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])

    def configure_optimizers(self):
optimizer = getattr(torch.optim, self.hparams.optimizer)(
self.parameters(), **self.hparams.optimizer_configs
)
return optimizer
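

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): shows one
# way this LightningModule might be instantiated and trained.  The synthetic
# dataset, the 14-feature input size, and the trainer flags below are
# assumptions, not project defaults.
if __name__ == "__main__":
    import torch.utils.data

    # Build (name, features, label) triples matching the batch layout the
    # steps above expect: batch[0]=names, batch[1]=inputs, batch[2]=labels.
    features = torch.randn(64, 14)
    labels = torch.randint(0, 2, (64,))
    names = [f"sample-{i}" for i in range(64)]
    loader = torch.utils.data.DataLoader(
        list(zip(names, features, labels)), batch_size=8
    )

    model = LogisticRegression(
        criterion=nn.BCEWithLogitsLoss(),
        criterion_valid=nn.BCEWithLogitsLoss(),
        optimizer="Adam",
        optimizer_configs={"lr": 1e-2},
        input_size=14,
    )

    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    trainer.fit(model, loader)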