Commit 6bb7ac55 authored by Daniel CARRON

Moved signs_to_tb model to lightning

parent ff712757
1 merge request: !4 Moved code to lightning
@@ -8,19 +8,22 @@ Simple feedforward network taking radiological signs as input and
 predicting tuberculosis presence in output.
 """
+from torch import empty
 from torch.nn import BCEWithLogitsLoss
-from torch.optim import Adam
 
-from ...models.signs_to_tb import build_signs_to_tb
+from ...models.signs_to_tb import SignsToTB
 
-# config
-lr = 1e-2
-
-# model
-model = build_signs_to_tb(14, 10)
+optimizer_configs = {"lr": 1e-2}
 
 # optimizer
-optimizer = Adam(model.parameters(), lr=lr)
+optimizer = "Adam"
 
 # criterion
-criterion = BCEWithLogitsLoss()
+criterion = BCEWithLogitsLoss(pos_weight=empty(1))
+criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
+
+# model
+model = SignsToTB(
+    criterion, criterion_valid, optimizer, optimizer_configs, 14, 10
+)
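The new config no longer instantiates the optimizer itself: it passes the optimizer by name ("Adam") plus a keyword dict, and builds both criteria with an uninitialized pos_weight placeholder, torch.empty(1). A minimal sketch of how that placeholder would presumably be filled before training; the label counts below are hypothetical, not from this repository:

# Sketch only, not part of the commit: the uninitialized pos_weight would
# typically be replaced with the real negative-to-positive class ratio
# computed from the training set. Counts here are hypothetical.
import torch
from torch.nn import BCEWithLogitsLoss

n_pos, n_neg = 300, 700  # hypothetical label counts
criterion = BCEWithLogitsLoss(pos_weight=torch.tensor([n_neg / n_pos]))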
@@ -2,15 +2,31 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import pytorch_lightning as pl
 import torch
 import torch.nn as nn
 
 
-class SignsToTB(nn.Module):
+class SignsToTB(pl.LightningModule):
     """Radiological signs to Tuberculosis module."""
 
-    def __init__(self, input_size, hidden_size):
+    def __init__(
+        self,
+        criterion,
+        criterion_valid,
+        optimizer,
+        optimizer_configs,
+        input_size,
+        hidden_size,
+    ):
         super().__init__()
+        self.save_hyperparameters()
+        self.name = "signs_to_tb"
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
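In the new constructor, save_hyperparameters() records every __init__ argument under self.hparams, which is what lets configure_optimizers (second hunk, below) look the optimizer up by name. A short sketch of that behavior, assuming the SignsToTB class from this commit is importable from its package:

# Sketch only: shows what save_hyperparameters() exposes, assuming SignsToTB
# from this commit has been imported.
from torch.nn import BCEWithLogitsLoss

model = SignsToTB(
    criterion=BCEWithLogitsLoss(),
    criterion_valid=BCEWithLogitsLoss(),
    optimizer="Adam",
    optimizer_configs={"lr": 1e-2},
    input_size=14,
    hidden_size=10,
)
print(model.hparams.optimizer)          # -> "Adam"
print(model.hparams.optimizer_configs)  # -> {"lr": 0.01}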
@@ -39,15 +55,55 @@ class SignsToTB(nn.Module):
 
         return output
 
-
-def build_signs_to_tb(input_size, hidden_size):
-    """Build SignsToTB shallow model.
-
-    Returns
-    -------
-    module : :py:class:`torch.nn.Module`
-    """
-    model = SignsToTB(input_size, hidden_size)
-    model.name = "signs_to_tb"
-    return model
+    def training_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # Forward pass on the network
+        outputs = self(images)
+
+        training_loss = self.criterion(outputs, labels.double())
+
+        return {"loss": training_loss}
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # data forwarding on the existing network
+        outputs = self(images)
+        validation_loss = self.criterion_valid(outputs, labels.double())
+
+        return {"validation_loss": validation_loss}
+
+    def predict_step(self, batch, batch_idx, grad_cams=False):
+        names = batch[0]
+        images = batch[1]
+
+        outputs = self(images)
+
+        # necessary check for HED architecture that uses several outputs
+        # for loss calculation instead of just the last concatfuse block
+        # (must run before sigmoid, which cannot take a list)
+        if isinstance(outputs, list):
+            outputs = outputs[-1]
+
+        probabilities = torch.sigmoid(outputs)
+
+        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+
+    def configure_optimizers(self):
+        # Dynamically instantiates the optimizer given the configs
+        optimizer = getattr(torch.optim, self.hparams.optimizer)(
+            self.parameters(), **self.hparams.optimizer_configs
+        )
+
+        return optimizer
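configure_optimizers resolves the optimizer class by name from torch.optim, which is why the config file can stay purely declarative. The same lookup pattern in isolation, as a standalone sketch; make_optimizer is a hypothetical helper, not part of the commit:

# Standalone sketch of the getattr-based optimizer lookup used above.
import torch

def make_optimizer(params, name: str, configs: dict) -> torch.optim.Optimizer:
    # Resolve e.g. "Adam" -> torch.optim.Adam, then instantiate with configs.
    return getattr(torch.optim, name)(params, **configs)

layer = torch.nn.Linear(14, 10)
opt = make_optimizer(layer.parameters(), "Adam", {"lr": 1e-2})
assert isinstance(opt, torch.optim.Adam)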