Commit d20d2311 authored by Daniel CARRON

Moved densenet_rs to lightning

parent 029a57a9
1 merge request: !4 Moved code to lightning
@@ -7,10 +7,10 @@
 A Densenet121 model for radiological extraction
 """
 
+from torch import empty
 from torch.nn import BCEWithLogitsLoss
-from torch.optim import Adam
 
-from ...models.densenet_rs import build_densenetrs
+from ...models.densenet_rs import DensenetRS
 
 # Import the default protocol if none is available
 if "dataset" not in locals():
@@ -19,16 +19,14 @@ if "dataset" not in locals():
     dataset = default.dataset
 
 # config
-lr = 1e-4
-
-# model
-model = build_densenetrs()
+optimizer_configs = {"lr": 1e-4}
 
 # optimizer
-optimizer = Adam(
-    filter(lambda p: p.requires_grad, model.model.model_ft.parameters()), lr=lr
-)
+optimizer = "Adam"
 
 # criterion
-criterion = BCEWithLogitsLoss()
-criterion_valid = BCEWithLogitsLoss()
+criterion = BCEWithLogitsLoss(pos_weight=empty(1))
+criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
+
+# model
+model = DensenetRS(criterion, criterion_valid, optimizer, optimizer_configs)
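
Note: after this change the config no longer constructs the optimizer itself. It only names it ("Adam") and supplies its keyword arguments; instantiation happens inside the module's configure_optimizers() (see the second file below). The pos_weight=empty(1) tensors appear to be placeholders, presumably overwritten once dataset statistics are known. A minimal sketch of how such a config might be consumed, assuming a hypothetical LightningDataModule named `datamodule` for the dataset above (only `DensenetRS` and the config values come from this commit):

    import pytorch_lightning as pl

    # Build the module from the config values defined above
    model = DensenetRS(criterion, criterion_valid, optimizer, optimizer_configs)

    # Hand training over to Lightning; `datamodule` is an assumption
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model, datamodule=datamodule)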
@@ -2,20 +2,35 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-from collections import OrderedDict
-
+import pytorch_lightning as pl
+import torch
 import torch.nn as nn
 import torchvision.models as models
 
 from .normalizer import TorchVisionNormalizer
 
 
-class DensenetRS(nn.Module):
+class DensenetRS(pl.LightningModule):
     """Densenet121 module for radiological extraction."""
 
-    def __init__(self):
+    def __init__(
+        self,
+        criterion,
+        criterion_valid,
+        optimizer,
+        optimizer_configs,
+    ):
         super().__init__()
+        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
+
+        self.name = "DensenetRS"
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
+        self.normalizer = TorchVisionNormalizer()
 
         # Load pretrained model
         self.model_ft = models.densenet121(
             weights=models.DenseNet121_Weights.DEFAULT
@@ -40,20 +55,61 @@ class DensenetRS(nn.Module):
         tensor : :py:class:`torch.Tensor`
 
         """
-        return self.model_ft(x)
-
-
-def build_densenetrs():
-    """Build DensenetRS CNN.
-
-    Returns
-    -------
-
-    module : :py:class:`torch.nn.Module`
-
-    """
-    model = DensenetRS()
-    model = [("normalizer", TorchVisionNormalizer()), ("model", model)]
-    model = nn.Sequential(OrderedDict(model))
-
-    model.name = "DensenetRS"
-    return model
+        x = self.normalizer(x)
+        x = self.model_ft(x)
+        return x
+
+    def training_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # Forward pass on the network
+        outputs = self(images)
+
+        training_loss = self.criterion(outputs, labels.double())
+
+        return {"loss": training_loss}
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # data forwarding on the existing network
+        outputs = self(images)
+        validation_loss = self.criterion_valid(outputs, labels.double())
+
+        return {"validation_loss": validation_loss}
+
+    def predict_step(self, batch, batch_idx, grad_cams=False):
+        names = batch[0]
+        images = batch[1]
+
+        outputs = self(images)
+        probabilities = torch.sigmoid(outputs)
+
+        # necessary check for HED architecture that uses several outputs
+        # for loss calculation instead of just the last concatfuse block
+        if isinstance(outputs, list):
+            outputs = outputs[-1]
+
+        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+
+    def configure_optimizers(self):
+        # Dynamically instantiates the optimizer given the configs
+        optimizer = getattr(torch.optim, self.hparams.optimizer)(
+            filter(lambda p: p.requires_grad, self.model_ft.parameters()),
+            **self.hparams.optimizer_configs,
+        )
+        return optimizer
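
Note: configure_optimizers() resolves the optimizer by name at runtime via getattr on torch.optim, so any optimizer class there (Adam, SGD, AdamW, ...) can be selected from the config without code changes, and the name plus its kwargs round-trip through save_hyperparameters(). A standalone sketch of the same pattern, with an illustrative stand-in for the DenseNet backbone:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)           # stand-in for the DenseNet backbone
    optimizer_name = "Adam"           # any class name from torch.optim works
    optimizer_configs = {"lr": 1e-4}

    # Look up the optimizer class by name and instantiate it over the
    # trainable parameters only, exactly as in configure_optimizers() above
    optimizer = getattr(torch.optim, optimizer_name)(
        filter(lambda p: p.requires_grad, model.parameters()),
        **optimizer_configs,
    )
    print(type(optimizer).__name__)   # -> "Adam"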