Skip to content
Snippets Groups Projects
logistic_regression.py 2.54 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import pytorch_lightning as pl
import torch
import torch.nn as nn


class LogisticRegression(pl.LightningModule):
    """Radiological signs to Tuberculosis module.

    A single linear layer mapping ``input_size`` input features to one
    output logit. :py:meth:`predict_step` turns that logit into a
    probability with a sigmoid.

    Parameters
    ----------

    criterion : callable
        Loss used during training, called as ``criterion(outputs, labels)``.

    criterion_valid : callable
        Loss used during validation, called as
        ``criterion_valid(outputs, labels)``.

    optimizer : str
        Name of an optimizer class in :py:mod:`torch.optim` (e.g.
        ``"Adam"``).

    optimizer_configs : dict
        Keyword arguments forwarded to the optimizer constructor.

    input_size : int
        Number of input features of the linear layer.
    """

    def __init__(
        self,
        criterion,
        criterion_valid,
        optimizer,
        optimizer_configs,
        input_size,
    ):
        super().__init__()

        # The losses are excluded from the saved hyper-parameters (presumably
        # because they may be ``nn.Module`` instances rather than plain
        # configuration values). As a result, ``self.hparams`` does NOT
        # contain them, so explicit references must be kept here for the
        # training/validation steps to use.
        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
        self.criterion = criterion
        self.criterion_valid = criterion_valid

        self.name = "logistic_regression"

        self.linear = nn.Linear(self.hparams.input_size, 1)

    def forward(self, x):
        """Apply the linear layer to ``x``.

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            Input features; the last dimension must equal ``input_size``.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            Un-normalised logits with the same leading dimensions as ``x``
            and a last dimension of size 1.

        """
        return self.linear(x)

    @staticmethod
    def _reshape_labels(labels):
        """Prepare target labels for the loss functions.

        A 1-D tensor of shape ``(N,)`` is reshaped to ``(N, 1)`` so it lines
        up with the single-logit output of :py:meth:`forward`. Tensors with
        more dimensions are left as-is, which allows both single- and
        multi-class label layouts. The result is always cast to float.
        """
        if labels.ndim == 1:
            labels = torch.reshape(labels, (labels.shape[0], 1))
        return labels.float()

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch.

        ``batch[1]`` holds the model inputs and ``batch[2]`` the targets.

        Returns
        -------

        dict
            ``{"loss": training_loss}``.
        """
        images = batch[1]
        labels = self._reshape_labels(batch[2])

        # Forward pass on the network
        outputs = self(images)

        training_loss = self.criterion(outputs, labels)

        return {"loss": training_loss}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch.

        ``batch[1]`` holds the model inputs and ``batch[2]`` the targets.

        Returns
        -------

        dict
            ``{"validation_loss": validation_loss}``.
        """
        images = batch[1]
        labels = self._reshape_labels(batch[2])

        # data forwarding on the existing network
        outputs = self(images)
        validation_loss = self.criterion_valid(outputs, labels)

        return {"validation_loss": validation_loss}

    def predict_step(self, batch, batch_idx, grad_cams=False):
        """Compute sigmoid probabilities for one batch.

        ``batch[0]`` holds sample names, ``batch[1]`` the model inputs and
        ``batch[2]`` the targets. ``grad_cams`` is accepted but unused here.

        Returns
        -------

        tuple
            ``(names[0], probabilities, labels)``, where the last two are
            flattened to 1-D. Only the first name of the batch is returned.
            NOTE(review): this looks like it assumes a batch size of 1;
            confirm against the prediction callback.
        """
        names = batch[0]
        images = batch[1]

        outputs = self(images)

        # necessary check for HED architecture that uses several outputs
        # for loss calculation instead of just the last concatfuse block.
        # It must run *before* the sigmoid, which cannot operate on a list.
        if isinstance(outputs, list):
            outputs = outputs[-1]

        probabilities = torch.sigmoid(outputs)

        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])

    def configure_optimizers(self):
        """Instantiate the ``torch.optim`` optimizer named in the hparams.

        Returns
        -------

        optimizer : :py:class:`torch.optim.Optimizer`
            ``getattr(torch.optim, optimizer)`` built over this module's
            parameters with ``optimizer_configs`` as keyword arguments.
        """
        optimizer = getattr(torch.optim, self.hparams.optimizer)(
            self.parameters(), **self.hparams.optimizer_configs
        )

        return optimizer